Skip to content

Commit eeb8fe9

Browse files
[GPU] Fix remaining issue in calculating present layout's padding for KVCache (openvinotoolkit#25706)
### Details:
- Follow-up to the remaining issue from openvinotoolkit#25682
- Fix an issue where kvcache was optimized out even if the calculated present layout's padding was negative

### Tickets:
- 146876
1 parent ff54835 commit eeb8fe9

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+5-8
Original file line numberDiff line numberDiff line change
@@ -1208,14 +1208,11 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
12081208
GPU_DEBUG_TRACE_DETAIL << "[do runtime kv_cache opt] " << id() << " initial present_layout : " << present_layout.to_string() << std::endl;
12091209
GPU_DEBUG_TRACE_DETAIL << "[do runtime kv_cache opt] " << id() << " initial past_layout : " << past_layout.to_string() << std::endl;
12101210
auto max_pad = kv_cache_inst::get_max_pad(past_layout, _deps[0].first->_max_output_layout_count[0], sequence_axis_legacy, "past_layout");
1211-
1212-
if (max_pad > 0) {
1213-
const auto new_seq_len = static_cast<int64_t>(new_layout.get_shape()[sequence_axis]);
1214-
if (max_pad - new_seq_len >= 0) {
1215-
kv_cache_inst::update_pad(present_layout, max_pad - new_seq_len, sequence_axis_legacy);
1216-
GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_layout's pad : "
1217-
<< present_layout.to_string() << std::endl;
1218-
}
1211+
const auto new_seq_len = static_cast<int64_t>(new_layout.get_shape()[sequence_axis]);
1212+
// In chatbot scenario, when chat history must be stored in kvcache, new_seq_len may not be 1 even if max_pad is greater than 0
1213+
if (max_pad - new_seq_len >= 0) {
1214+
kv_cache_inst::update_pad(present_layout, max_pad - new_seq_len, sequence_axis_legacy);
1215+
GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id() << " Updated present_layout's pad : " << present_layout.to_string() << std::endl;
12191216
auto& variable = get_network().get_variable(desc->variable_info.variable_id);
12201217
variable.set_layout(present_layout);
12211218
GPU_DEBUG_TRACE_DETAIL << "[do_runtime_in_place_kv_cache] " << id() << "Updated variable with present_layout"

0 commit comments

Comments
 (0)