Commit bb7f8d3

[GPU] Minor refactoring (openvinotoolkit#25629)
### Details:
- Refactor according to the comments in PR25449

### Tickets:
- *ticket-id*
1 parent eeb8fe9 commit bb7f8d3

File tree: 3 files changed (+31 -15 lines)


src/plugins/intel_gpu/include/intel_gpu/graph/kernel_impl_params.hpp

+9

@@ -114,6 +114,15 @@ struct kernel_impl_params final {
         return output_layouts[idx];
     }
 
+    layout& get_output_layout(size_t idx = 0) {
+        OPENVINO_ASSERT(output_layouts.size() > idx,
+                        "The size of output layouts must be greater than the requested index: ",
+                        "Requested index is ", idx, ",",
+                        "but the size of output layouts is ", output_layouts.size());
+        return output_layouts[idx];
+    }
+
+
     bool has_fused_primitives() const { return !fused_desc.empty(); }
 
     ov::element::Type_t get_output_element_type() const {
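The added non-const `get_output_layout` mirrors the existing const accessor: it bounds-checks the index with `OPENVINO_ASSERT` and returns a mutable reference, so a caller can adjust a stored output layout in place instead of rebuilding the vector. A minimal usage sketch, assuming it is compiled inside the intel_gpu plugin; the helper name below is hypothetical and not part of this commit:

```cpp
// Hypothetical usage sketch (not part of this commit); assumes the plugin's
// include paths so the header below resolves.
#include "intel_gpu/graph/kernel_impl_params.hpp"

namespace cldnn {

// Illustrative helper: overwrite the format of the first output layout in place.
inline void set_first_output_format(kernel_impl_params& params, format new_fmt) {
    // The non-const overload asserts idx < output_layouts.size() and then
    // returns a mutable reference into the stored vector.
    layout& out = params.get_output_layout(0);
    out.format = new_fmt;
}

}  // namespace cldnn
```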

src/plugins/intel_gpu/include/intel_gpu/runtime/layout.hpp

+10 -1

@@ -288,6 +288,15 @@ struct layout {
         return *this;
     }
 
+    layout clone_with_other_shape(const ov::PartialShape& new_shape) {
+        return layout(new_shape, this->data_type, this->format, this->data_padding);
+    }
+
+    layout clone_with_other_shape(const ov::Shape& new_shape) {
+        return clone_with_other_shape(ov::PartialShape(new_shape));
+    }
+
+
     friend bool operator==(const layout& lhs, const layout& rhs) {
         return lhs.data_type == rhs.data_type && lhs.format == rhs.format && lhs.size == rhs.size && lhs.data_padding == rhs.data_padding;
     }
@@ -306,7 +315,7 @@ struct layout {
         return (lhs.data_padding < rhs.data_padding);
     }
 
-    /// Number of elements to be stored in this memory layout
+    /// Number of elements to be stored in this layout
     size_t count() const;
 
     /// Layout size with padding included
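`clone_with_other_shape` packages the common copy-then-reshape pattern: it returns a new layout that keeps the original data type, format, and padding and swaps in the given shape, with the `ov::Shape` overload forwarding to the `ov::PartialShape` one. A sketch of the call-site simplification, assuming the plugin's include paths; the helper name below is illustrative, not from this commit:

```cpp
#include "intel_gpu/runtime/layout.hpp"

// Illustrative helper: derive the layout of a resized buffer from an existing one.
inline cldnn::layout with_new_shape(cldnn::layout base, const ov::PartialShape& new_shape) {
    // Before this commit the same intent took several statements:
    //     auto l = base;
    //     l.set_partial_shape(new_shape);
    //     return l;
    // The new helper keeps data_type, format, and data_padding and replaces only the shape.
    return base.clone_with_other_shape(new_shape);
}
```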

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+12 -14

@@ -465,7 +465,7 @@ void primitive_inst::update_shape() {
         auto desc = get_node().as<kv_cache>().get_primitive();
         auto var_mem_size = get_network().get_variable(desc->variable_info.variable_id).get_actual_mem_size();
         // Need to trigger realloc_if_needed
-        if (var_mem_size < _impl_params->get_output_layout(0).get_buffer_size().count())
+        if (var_mem_size < _impl_params->get_output_layout(0).get_linear_size())
             set_shape_change();
     }
 }
@@ -684,13 +684,13 @@ event::ptr primitive_inst::realloc_if_needed() {
             prealloc_shape[seq_axis] += tmp_prealloc_count;
             required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
         } else {
-            required_buffer_size = (updated_layouts[i].get_buffer_size().count());
+            required_buffer_size = (updated_layouts[i].get_linear_size());
         }
         if (required_buffer_size * 10 < _max_output_layout_count[i]) {
             reclaim = true;
         }
         if (reclaim) {
-            GPU_DEBUG_TRACE_DETAIL << id() << ": Updated output[" << i << "] size " << updated_layouts[i].get_buffer_size().count()
+            GPU_DEBUG_TRACE_DETAIL << id() << ": Updated output[" << i << "] size " << updated_layouts[i].get_linear_size()
                                    << " is much smaller than current memory size! " << _max_output_layout_count[i]
                                    << "Reset memory of output " << i << std::endl;
             _max_output_layout_count[i] = 0;
@@ -705,7 +705,7 @@ event::ptr primitive_inst::realloc_if_needed() {
     }
 
     for (size_t i = 0; i < actual_layouts.size(); ++i) {
-        bool can_reuse_buffer = (_outputs[i] && updated_layouts[i].get_buffer_size().count() <= _max_output_layout_count[i]);
+        bool can_reuse_buffer = (_outputs[i] && updated_layouts[i].get_linear_size() <= _max_output_layout_count[i]);
         std::pair<bool, ov::Shape> prealloc_info;
         if (_node->is_type<kv_cache>() && i == 0) {
             const auto& desc = _node->as<kv_cache>().get_primitive();
@@ -717,17 +717,15 @@ event::ptr primitive_inst::realloc_if_needed() {
             prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], can_reuse_buffer, i, tmp_prealloc_count);
         }
         if (prealloc_info.first && sp.can_preallocate(ov::shape_size(prealloc_info.second) * (dt_sizes_in_B[i]))) {
-            auto new_layout = updated_layouts[i];
-            new_layout.set_partial_shape(prealloc_info.second);
-            updated_params.output_layouts[i] = new_layout;
+            updated_params.output_layouts[i] = updated_layouts[i].clone_with_other_shape(prealloc_info.second);
         }
-        if (updated_params.output_layouts[i].get_buffer_size().count() < updated_layouts[i].get_buffer_size().count()) {
+        if (updated_params.output_layouts[i].get_linear_size() < updated_layouts[i].get_linear_size()) {
             updated_params.output_layouts[i] = updated_layouts[i];
         }
 
         if (can_reuse_buffer) {
             GPU_DEBUG_TRACE_DETAIL << id() << ": reuse previously allocated output buffer[" << i << "] - "
-                                   << actual_layouts[i].get_buffer_size().count() << "/" << _max_output_layout_count[i]
+                                   << actual_layouts[i].get_linear_size() << "/" << _max_output_layout_count[i]
                                    << std::endl;
             if (_node->is_type<kv_cache>() && (i == 0)) {
                 // kv_cache has already assigned memory.
@@ -759,7 +757,7 @@ event::ptr primitive_inst::realloc_if_needed() {
             GPU_DEBUG_TRACE_DETAIL << id() << ": realloc output memory. " << std::endl;
             GPU_DEBUG_TRACE_DETAIL << " outputs[" << i << "] "
                                    << " Current buffer_size=" << _max_output_layout_count[i]
-                                   << " Requested buffer_size=" << updated_layouts[i].get_buffer_size().count()
+                                   << " Requested buffer_size=" << updated_layouts[i].get_linear_size()
                                    << std::endl;
             _outputs[i] = allocate_output(_network.get_engine(),
                                           _network.get_memory_pool(),
@@ -773,7 +771,7 @@ event::ptr primitive_inst::realloc_if_needed() {
                                           is_output_buffer(this, true),
                                           output_memory_ptr(i).get(),
                                           true);
-            _max_output_layout_count[i] = updated_params.output_layouts[i].get_buffer_size().count();
+            _max_output_layout_count[i] = updated_params.output_layouts[i].get_linear_size();
             GPU_DEBUG_CODE(std::string memalloc_info = "");
             GPU_DEBUG_CODE(memalloc_info += (((_outputs.size() > 1) ? ("o" + to_string(i) + ":") : "") +
                            (_outputs[i]->from_memory_pool ? "from_pool" : "new_alloc"));)
@@ -1852,7 +1850,7 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
     _impl_params->strm = _network.get_stream_ptr();
     for (size_t i = 0; i < get_node().get_output_layouts().size(); ++i) {
         if (_outputs.size() > i) {
-            _max_output_layout_count.push_back(_outputs[i] ? _outputs[i]->get_layout().get_buffer_size().count() : 0);
+            _max_output_layout_count.push_back(_outputs[i] ? _outputs[i]->get_layout().get_linear_size() : 0);
         } else {
             _outputs.push_back(nullptr);
             _max_output_layout_count.push_back(0);
@@ -1985,9 +1983,9 @@ event::ptr primitive_inst::update_weights() {
             GPU_DEBUG_TRACE_DETAIL << id() << ": add original weights memory " << original_layout.to_short_string() << " to weights cache; "
                                    << "cache_size=" << _reordered_weights_cache.size() << "/" << _reordered_weights_cache.capacity() << std::endl;
         } else {
-            auto expected_layout = reorder_kernel_params->get_output_layout();
             // Set original partial shape, because it may be lost during kernel_selector::weights_tensor -> layout conversion
-            expected_layout.set_partial_shape(original_layout.get_partial_shape());
+            auto expected_layout =
+                reorder_kernel_params->get_output_layout().clone_with_other_shape(original_layout.get_partial_shape());
             _impl_params->weights_layout = optional_layout(expected_layout);
 
             if (_reordered_weights_cache.has(expected_layout)) {
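The remaining `primitive_inst.cpp` hunks replace the two-step `get_buffer_size().count()` chain with the single `layout::get_linear_size()` call. Since the commit is described as a pure refactor, the two are expected to report the same padded element count; a small sanity-check sketch of that assumption (the function name below is illustrative, not from this commit):

```cpp
#include "intel_gpu/runtime/layout.hpp"
#include <cassert>

// Assumption relied on by this refactor: get_linear_size() equals the value
// the old get_buffer_size().count() chain produced for the same layout.
inline void check_linear_size(const cldnn::layout& l) {
    assert(l.get_linear_size() == l.get_buffer_size().count());
}
```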
