Skip to content

Commit e62d0fa

Browse files
authored
[GPU] Improve kv cache memory allocation efficiency (#25580)
### Details: - Fixed two issues - 1) KV cache was allocating redundant memory when it requires new memory - 2) At a new inference, KV cache was setting a padding value as the one used in the previous execution (last token for the previous generation), which caused memory usage inefficiency. - After fixing above issues, in some cases, memory is more frequently allocated because - 1) switching shape 1024 => 32 : happens reclaiming (previously due to the wrong padding, it is not reclaimed.) - 2) switching shape 32 => 1024 : new alloc needed at the first infer, but shape history is not tracked yet. So during 3 iteration, it is allocating new memory. - Additional fix to resolve above issues: - 1) For initial allocation of kv cache, enforce prealloc with custom prealloc count (known value of 128 + id%64) for sequence axis - 2) For reclaiming kv cache : use prealloc size as the required memory size Memalloc count with PR ![image](https://github.com/user-attachments/assets/c65b3335-c849-46f8-b9fe-140c3a0fbccb) ### Tickets: - 146930
1 parent 4dbe733 commit e62d0fa

File tree

3 files changed

+94
-34
lines changed

3 files changed

+94
-34
lines changed

src/plugins/intel_gpu/include/intel_gpu/runtime/shape_predictor.hpp

+21-15
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,31 @@ struct ShapePredictor {
3434
static_assert(_max_deque_size >= 2, "[GPU] Deque is supposed to contain at least 2 elements for prediction");
3535
}
3636

37-
38-
/// \brief Predicts the next possible shapes sizes based on history collected by previous
39-
/// predict_preallocation_shape() calls.
40-
/// This function works in two modes: by default it tries to predict shape for the next
41-
/// `_next_iters_preallocation_count` iterations, in case if per-iteration buffer size is less than
42-
/// `_max_per_iter_size` and difference between shapes is less than `_max_per_dim_diff`; the second
43-
/// operation mode is percentage preallocation - this mode can be configured with
44-
/// ov::intel_gpu::buffers_preallocation_ratio property, it increases buffer size by
45-
/// `_buffers_preallocation_ratio` value unconditionally.
46-
/// \param id Primitive id.
47-
/// \param layout Primitive's layout on current iteration.
48-
/// \param can_reuse_buffer Specifies if current memory buffer is enough to store data.
49-
/// \return The result of shape size prediction as std::pair<bool, ov::Shape>, where the first element
50-
/// says if shape is successfully predicted and can be preallocated, and the second element is ov::Shape itself.
37+
/// \brief Predicts the next possible shapes sizes based on history collected by previous
38+
/// predict_preallocation_shape() calls.
39+
/// This function works in two modes: by default it tries to predict shape for the next
40+
/// `_next_iters_preallocation_count` iterations, in case if per-iteration buffer size is less than
41+
/// `_max_per_iter_size` and difference between shapes is less than `_max_per_dim_diff`; the second
42+
/// operation mode is percentage preallocation - this mode can be configured with
43+
/// ov::intel_gpu::buffers_preallocation_ratio property, it increases buffer size by
44+
/// `_buffers_preallocation_ratio` value unconditionally.
45+
/// \param id Primitive id.
46+
/// \param layout Primitive's layout on current iteration.
47+
/// \param can_reuse_buffer Specifies if current memory buffer is enough to store data.
48+
/// \param out_idx Output index among multiple outputs
49+
/// \param custom_next_iters_prealloc_count If it is specified, enforce prealloc size as the specified value
50+
/// \param custom_prealloc_dim If both custom_next_iters_prealloc_count and custom_prealloc_dim are specified,
51+
/// increase custom_prealloc_dim with custom_next_iters_prealloc_count without checking shape history (e.g.,
52+
/// used for first inference of kv cache)
53+
/// \return The result of shape size prediction as std::pair<bool, ov::Shape>, where
54+
/// the first element says if shape is successfully predicted and can be preallocated, and the second
55+
/// element is ov::Shape itself.
5156
std::pair<bool, ov::Shape> predict_preallocation_shape(const std::string& id,
5257
const cldnn::layout& layout,
5358
bool can_reuse_buffer,
5459
const size_t out_idx = 0,
55-
int32_t next_iters_prealloc_count = -1);
60+
int32_t custom_next_iters_prealloc_count = -1,
61+
int32_t custom_prealloc_dim = -1);
5662

5763
bool can_preallocate(size_t desired_buffer_size);
5864

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+64-18
Original file line numberDiff line numberDiff line change
@@ -664,9 +664,32 @@ event::ptr primitive_inst::realloc_if_needed() {
664664
updated_layouts[0] = layout(current_buf_shape, updated_layouts[0].data_type, updated_layouts[0].format);
665665
}
666666

667+
int32_t tmp_prealloc_count = get_prealloc_iter_num();
668+
GPU_DEBUG_IF(debug_config->mem_preallocation_params.is_initialized) {
669+
// If debug config is set, respect the config most
670+
tmp_prealloc_count = -1;
671+
}
672+
667673
// If we allocated too large memory, reclaim the memory.
668674
for (size_t i = 0; i < updated_layouts.size(); ++i) {
669-
if (updated_layouts[i].get_buffer_size().count() * 10 < _max_output_layout_count[i]) {
675+
bool reclaim = 0;
676+
size_t required_buffer_size = 0;
677+
if (_node->is_type<kv_cache>() && i == 0) {
678+
// Relax reclaiming condition for kv cache
679+
const auto& desc = _node->as<kv_cache>().get_primitive();
680+
auto prealloc_shape = updated_layouts[i].get_shape();
681+
const auto shape_rank = prealloc_shape.size();
682+
auto seq_axis =
683+
static_cast<int32_t>(desc->concat_axis >= 0 ? desc->concat_axis : shape_rank + desc->concat_axis);
684+
prealloc_shape[seq_axis] += tmp_prealloc_count;
685+
required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
686+
} else {
687+
required_buffer_size = (updated_layouts[i].get_buffer_size().count());
688+
}
689+
if (required_buffer_size * 10 < _max_output_layout_count[i]) {
690+
reclaim = true;
691+
}
692+
if (reclaim) {
670693
GPU_DEBUG_TRACE_DETAIL << id() << ": Updated output[" << i << "] size " << updated_layouts[i].get_buffer_size().count()
671694
<< " is much smaller than current memory size! " << _max_output_layout_count[i]
672695
<< "Reset memory of output " << i << std::endl;
@@ -681,31 +704,51 @@ event::ptr primitive_inst::realloc_if_needed() {
681704
return ev;
682705
}
683706

684-
int32_t tmp_prealloc_count = get_prealloc_iter_num();
685-
GPU_DEBUG_IF(debug_config->mem_preallocation_params.is_initialized) {
686-
// If debug config is set, respect the config most
687-
tmp_prealloc_count = -1;
688-
}
689-
690707
for (size_t i = 0; i < actual_layouts.size(); ++i) {
691708
bool can_reuse_buffer = (_outputs[i] && updated_layouts[i].get_buffer_size().count() <= _max_output_layout_count[i]);
692-
std::pair<bool, ov::Shape> prealloc_info =
693-
sp.predict_preallocation_shape(id(), updated_layouts[i], can_reuse_buffer, i, tmp_prealloc_count);
709+
std::pair<bool, ov::Shape> prealloc_info;
710+
if (_node->is_type<kv_cache>() && i == 0) {
711+
const auto& desc = _node->as<kv_cache>().get_primitive();
712+
auto shape_rank = updated_layouts[i].get_shape().size();
713+
auto seq_axis =
714+
static_cast<int32_t>(desc->concat_axis >= 0 ? desc->concat_axis : shape_rank + desc->concat_axis);
715+
prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
716+
} else {
717+
prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], can_reuse_buffer, i, tmp_prealloc_count);
718+
}
694719
if (prealloc_info.first && sp.can_preallocate(ov::shape_size(prealloc_info.second) * (dt_sizes_in_B[i]))) {
695720
auto new_layout = updated_layouts[i];
696721
new_layout.set_partial_shape(prealloc_info.second);
697722
updated_params.output_layouts[i] = new_layout;
698723
}
699-
700724
if (updated_params.output_layouts[i].get_buffer_size().count() < updated_layouts[i].get_buffer_size().count()) {
701725
updated_params.output_layouts[i] = updated_layouts[i];
702726
}
703727

704728
if (can_reuse_buffer) {
705-
GPU_DEBUG_TRACE_DETAIL << id() << ": reuse previously allocated output buffer - "
729+
GPU_DEBUG_TRACE_DETAIL << id() << ": reuse previously allocated output buffer[" << i << "] - "
706730
<< actual_layouts[i].get_buffer_size().count() << "/" << _max_output_layout_count[i]
707731
<< std::endl;
708-
_outputs[i] = _network.get_engine().reinterpret_buffer(*_outputs[i], actual_layouts[i]);
732+
if (_node->is_type<kv_cache>() && (i == 0)) {
733+
// kv_cache has already assigned memory.
734+
// No need to reinterpret output memory but need to update padding
735+
const auto& desc = _node->as<kv_cache>().get_primitive();
736+
auto& present_layout = _impl_params->output_layouts[i];
737+
const auto present_layout_rank = present_layout.get_partial_shape().size();
738+
const auto sequence_axis =
739+
desc->concat_axis >= 0 ? desc->concat_axis : present_layout_rank + desc->concat_axis;
740+
741+
const auto sequence_axis_legacy = kv_cache_inst::get_sequence_axis_legacy(sequence_axis, present_layout_rank);
742+
auto max_pad = kv_cache_inst::get_max_pad(present_layout,
743+
_max_output_layout_count[i],
744+
sequence_axis_legacy,
745+
"present_layout");
746+
kv_cache_inst::update_pad(present_layout, max_pad, sequence_axis_legacy);
747+
GPU_DEBUG_TRACE_DETAIL << _impl_params->output_layouts[i].to_string() << std::endl;
748+
set_shape_change();
749+
} else {
750+
_outputs[i] = _network.get_engine().reinterpret_buffer(*_outputs[i], actual_layouts[i]);
751+
}
709752
// TODO: check need_reset_output_memory per output
710753
if (need_reset_output_memory() && !can_be_optimized()) {
711754
GPU_DEBUG_TRACE_DETAIL << id() << " : Need reset output memory considering user" << std::endl;
@@ -740,7 +783,7 @@ event::ptr primitive_inst::realloc_if_needed() {
740783

741784
// Set variable memory same as output memory
742785
if (_node->is_type<kv_cache>()) {
743-
auto desc = _node->as<kv_cache>().get_primitive();
786+
const auto& desc = _node->as<kv_cache>().get_primitive();
744787
auto& variable = get_network().get_variable(desc->variable_info.variable_id);
745788
auto present_layout = _impl_params->output_layouts[0];
746789
auto present_layout_rank = present_layout.get_partial_shape().size();
@@ -760,7 +803,7 @@ event::ptr primitive_inst::realloc_if_needed() {
760803
if (present_layout.data_padding.get_dynamic_pad_dims().sizes()[sequence_axis_legacy] == 1) {
761804
// Apply padding of variable to make it be optimized in the next iteration
762805
auto max_pad = kv_cache_inst::get_max_pad(present_layout,
763-
updated_params.output_layouts[0].get_buffer_size().count(),
806+
_max_output_layout_count[0],
764807
sequence_axis_legacy,
765808
"present_layout");
766809
if (max_pad > 0) {
@@ -783,7 +826,7 @@ event::ptr primitive_inst::realloc_if_needed() {
783826
GPU_DEBUG_TRACE_DETAIL << id() << ": Update variable " << variable.get_name()
784827
<< "'s layout with allocated kv cache output: " << present_layout.to_short_string()
785828
<< " (is_set = " << variable.is_set() << ") " << std::endl;
786-
variable.set_layout(present_layout);
829+
variable.set_memory(_outputs[0], present_layout);
787830
}
788831
} else {
789832
GPU_DEBUG_TRACE_DETAIL << id() << ": Update variable " << variable.get_name()
@@ -1036,8 +1079,10 @@ void primitive_inst::update_paddings() {
10361079
auto reset_pad = [](kernel_impl_params& params, const program_node* node) {
10371080
params.output_layouts[0].data_padding = node->get_output_layout(0).data_padding;
10381081
};
1039-
if (_node->is_type<read_value>()) {
1040-
auto& variable = get_network().get_variable(_node->as<read_value>().get_primitive()->variable_id);
1082+
if (_node->is_type<read_value>() || _node->is_type<kv_cache>()) {
1083+
auto variable_id = _node->is_type<read_value>() ? (_node->as<read_value>().get_primitive()->variable_id)
1084+
: (_node->as<kv_cache>().get_primitive()->variable_info.variable_id);
1085+
auto& variable = get_network().get_variable(variable_id);
10411086
// Reset paddings for read_value and users with dynamic pad when variable is reset
10421087
// to avoid wrong pad used for some nodes due to pad propagation logic (which uses previous iter pad values)
10431088
if (!variable.is_set()) {
@@ -1054,6 +1099,7 @@ void primitive_inst::update_paddings() {
10541099
}
10551100
return;
10561101
}
1102+
10571103
if (_node->is_type<gather>() && _impl_params->output_layouts[0].data_padding.get_dynamic_pad_dims() != tensor(0)) {
10581104
if (can_be_optimized())
10591105
_impl_params->output_layouts[0] = _impl_params->input_layouts[0];
@@ -1141,7 +1187,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
11411187
if (_impl_params->get_input_layout(0).count() == 0) {
11421188
return;
11431189
}
1144-
auto desc = _node->as<kv_cache>().get_primitive();
1190+
const auto& desc = _node->as<kv_cache>().get_primitive();
11451191
auto& past_layout = _impl_params->input_layouts[0];
11461192
auto& present_layout = _impl_params->output_layouts[0];
11471193
const auto& sequence_axis = desc->concat_axis;

src/plugins/intel_gpu/src/runtime/shape_predictor.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ std::pair<bool, ov::Shape> ShapePredictor::predict_preallocation_shape(const std
6060
const cldnn::layout& layout,
6161
bool can_reuse_buffer,
6262
const size_t out_idx,
63-
int32_t custom_next_iters_prealloc_count) {
63+
int32_t custom_next_iters_prealloc_count,
64+
int32_t custom_prealloc_dim) {
6465
size_t next_iters_prealloc_count = custom_next_iters_prealloc_count > 0
6566
? static_cast<size_t>(custom_next_iters_prealloc_count)
6667
: _next_iters_preallocation_count;
@@ -79,6 +80,13 @@ std::pair<bool, ov::Shape> ShapePredictor::predict_preallocation_shape(const std
7980
if (can_reuse_buffer)
8081
return {false, {}};
8182

83+
// If both prealloc dim and prealloc count are specified, don't predict and just use the given info
84+
if (custom_prealloc_dim >= 0 && custom_next_iters_prealloc_count > 0) {
85+
auto new_shape = current_shape;
86+
new_shape[custom_prealloc_dim] += custom_next_iters_prealloc_count;
87+
return {true, new_shape};
88+
}
89+
8290
// Check if there is enough data for prediction
8391
const auto& shapes = _shapes_info[id_record];
8492
const auto shapes_num = shapes.size();

0 commit comments

Comments
 (0)