Skip to content

Commit a33afe4

Browse files
authored
[GPU] Fix Crop->Reshape (Squeeze/Unsqueeze modes) buffer optimization (openvinotoolkit#25836)
These changes fix a significant accuracy issue (reducing perplexity from 120,000 to 17) for Llama models with precalculated constant sin/cos values. However, there is still a problem with sin/cos representation in FP16 precision, which will be addressed in a separate PR. ### Details: - Fixed Crop->Reshape (Squeeze/Unsqueeze modes) buffer optimization - Updated the rope_ref kernel to support dynamic paddings for cos/sin inputs - Fixed the propagate_padding() function and updated shape infer tests ### Tickets: - [CVS-148220](https://jira.devtools.intel.com/browse/CVS-148220), [CVS-146283](https://jira.devtools.intel.com/browse/CVS-146283)
1 parent b2319a5 commit a33afe4

File tree

10 files changed

+249
-55
lines changed

10 files changed

+249
-55
lines changed

src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp

+97-29
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,12 @@ bool crop_in_place_optimization::match(const program_node& node,
512512
if (node.get_program().is_body_program() && node.get_dependency(0).is_type<lstm_elt>()) {
513513
return false;
514514
}
515+
516+
GPU_DEBUG_GET_INSTANCE(debug_config);
517+
GPU_DEBUG_IF(debug_config->disable_runtime_buffer_fusing && node.is_dynamic()) {
518+
return false;
519+
}
520+
515521
// optimization is available for cropping across depth(features) or batch
516522
// if output padding has defined padding across features already it wouldn't
517523
// work because it expect to have zeros in the padded area.
@@ -553,18 +559,22 @@ bool crop_in_place_optimization::optimize(crop_node& node) {
553559
node.get_primitive()->axis,
554560
false);
555561
} else if (can_crop_be_optimized_simple_data_format(crop_layout, input_layout)) {
556-
std::vector<layout> reshape_layouts;
557-
if (node.get_users().front()->is_type<reshape>() && node.get_users().front()->as<reshape>().is_runtime_propagatable_padding()) {
558-
reshape_layouts.push_back(node.get_users().front()->get_output_layout());
562+
std::pair<const program_node*, layout> user_info;
563+
if (node.get_users().front()->is_type<reshape>()) {
564+
auto& reshape_node = node.get_users().front()->as<reshape>();
565+
if (reshape_node.is_runtime_propagatable_padding()) {
566+
user_info.first = &reshape_node;
567+
user_info.second = reshape_node.get_output_layout();
568+
}
559569
}
560570
update_in_place_crop_padding_simple_data_format(crop_layout,
561571
input_layout,
562-
reshape_layouts,
572+
user_info,
563573
crop_params->input_offsets[0],
564574
node.get_primitive()->axis,
565575
false);
566-
if (reshape_layouts.size() > 0) {
567-
node.get_users().front()->set_output_layout(reshape_layouts[0]);
576+
if (user_info.first) {
577+
node.get_users().front()->set_output_layout(user_info.second);
568578
}
569579
}
570580
node.set_output_layout(crop_layout);
@@ -632,24 +642,51 @@ void crop_in_place_optimization::update_in_place_crop_padding_along_feature(cons
632642

633643
void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(layout& crop_layout,
634644
layout& input_layout,
635-
std::vector<layout>& user_layouts,
645+
std::pair<const program_node*, layout>& user_info,
636646
const tensor offsets,
637647
size_t crop_axis,
638648
bool is_runtime) {
639-
auto crop_axis_legacy = crop_axis;
640-
if (crop_axis_legacy >= 2) {
641-
auto spatial_axis = crop_axis_legacy - 2;
642-
// Default and minimum number of dimensions is 4
643-
auto spatial_size = std::max<size_t>(crop_layout.get_partial_shape().size(), 4) - 2;
644-
crop_axis_legacy = spatial_size - spatial_axis - 1 + 2;
645-
}
649+
auto convert_axis_to_legacy = [](size_t axis, size_t rank) {
650+
auto axis_legacy = axis;
651+
if (axis_legacy >= 2) {
652+
auto spatial_axis = axis_legacy - 2;
653+
// Default and minimum number of dimensions is 4
654+
auto spatial_size = std::max<size_t>(rank, 4) - 2;
655+
axis_legacy = spatial_size - spatial_axis - 1 + 2;
656+
}
657+
658+
return axis_legacy;
659+
};
660+
661+
auto crop_axis_legacy = convert_axis_to_legacy(crop_axis, crop_layout.get_partial_shape().size());
662+
646663
// If it's build-time and node is dynamic, only dynamic padding is set first
647664
if ((crop_layout.is_dynamic() || input_layout.is_dynamic()) && !is_runtime) {
648665
auto dyn_pad_sizes = tensor(0).sizes();
649666
dyn_pad_sizes[crop_axis_legacy] = 1;
650667
crop_layout.data_padding.set_dynamic_pad(tensor(dyn_pad_sizes));
651-
for (auto& user_layout : user_layouts) {
652-
user_layout.data_padding.set_dynamic_pad(tensor(dyn_pad_sizes));
668+
669+
if (user_info.first && user_info.first->is_type<reshape>()) {
670+
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
671+
auto reshape_mode = reshape_desc->mode;
672+
if (reshape_mode == reshape::reshape_mode::base) {
673+
user_info.second.data_padding.set_dynamic_pad(tensor(dyn_pad_sizes));
674+
} else if (reshape_mode == reshape::reshape_mode::unsqueeze || reshape_mode == reshape::reshape_mode::squeeze) {
675+
auto reshape_ps = user_info.second.get_partial_shape();
676+
auto output_pattern = reshape_desc->output_pattern;
677+
678+
auto reshape_axis = crop_axis;
679+
for (size_t i = 0; i < output_pattern.size(); i++) {
680+
if (output_pattern[i] <= static_cast<int64_t>(reshape_axis)) {
681+
reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1;
682+
}
683+
}
684+
685+
auto dyn_pad_mask = tensor(0).sizes();
686+
auto reshape_axis_legacy = convert_axis_to_legacy(reshape_axis, reshape_ps.size());
687+
dyn_pad_mask[reshape_axis_legacy] = 1;
688+
user_info.second.data_padding.set_dynamic_pad(tensor(dyn_pad_mask));
689+
}
653690
}
654691
return;
655692
}
@@ -673,14 +710,40 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
673710
auto dyn_pad_sizes = lower_sizes;
674711
dyn_pad_sizes[crop_axis_legacy] = 1;
675712
crop_layout.data_padding = padding(lower_sizes, upper_sizes, 0.f, tensor(dyn_pad_sizes));
676-
for (auto& user_layout : user_layouts) {
677-
auto reshape_rank = user_layout.get_partial_shape().size();
678-
auto reshape_last_dim = user_layout.get_partial_shape().to_shape()[reshape_rank - 1];
679-
if (lower_sizes[crop_axis_legacy])
680-
lower_sizes[crop_axis_legacy] /= reshape_last_dim;
681-
if (upper_sizes[crop_axis_legacy])
682-
upper_sizes[crop_axis_legacy] /= reshape_last_dim;
683-
user_layout.data_padding = padding(lower_sizes, upper_sizes, 0.f, tensor(dyn_pad_sizes));
713+
if (user_info.first) {
714+
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
715+
auto reshape_mode = reshape_desc->mode;
716+
if (reshape_mode == reshape::reshape_mode::base) {
717+
auto reshape_rank = user_info.second.get_partial_shape().size();
718+
auto reshape_last_dim = user_info.second.get_partial_shape().to_shape()[reshape_rank - 1];
719+
if (lower_sizes[crop_axis_legacy])
720+
lower_sizes[crop_axis_legacy] /= reshape_last_dim;
721+
if (upper_sizes[crop_axis_legacy])
722+
upper_sizes[crop_axis_legacy] /= reshape_last_dim;
723+
user_info.second.data_padding = padding(lower_sizes, upper_sizes, 0.f, tensor(dyn_pad_sizes));
724+
} else {
725+
auto reshape_ps = user_info.second.get_partial_shape();
726+
auto output_pattern = reshape_desc->output_pattern;
727+
728+
auto reshape_axis = crop_axis;
729+
for (size_t i = 0; i < output_pattern.size(); i++) {
730+
if (output_pattern[i] <= static_cast<int64_t>(reshape_axis)) {
731+
reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1;
732+
}
733+
}
734+
735+
const auto output_rank = std::max(reshape_ps.size(), static_cast<size_t>(4));
736+
std::vector<int32_t> reshape_lower_sizes(output_rank, 0);
737+
std::vector<int32_t> reshape_upper_sizes(output_rank, 0);
738+
std::vector<int32_t> reshape_dyn_pad_mask(output_rank, 0);
739+
740+
const auto reshape_axis_legacy = convert_axis_to_legacy(reshape_axis, reshape_ps.size());
741+
reshape_lower_sizes[reshape_axis_legacy] = lower_sizes[crop_axis_legacy];
742+
reshape_upper_sizes[reshape_axis_legacy] = upper_sizes[crop_axis_legacy];
743+
reshape_dyn_pad_mask[reshape_axis_legacy] = 1;
744+
745+
user_info.second.data_padding = padding(reshape_lower_sizes, reshape_upper_sizes, 0.f, tensor(reshape_dyn_pad_mask));
746+
}
684747
}
685748
} else {
686749
crop_layout.data_padding = padding(lower_sizes, upper_sizes);
@@ -743,18 +806,23 @@ void prepare_buffer_fusing::run(program& p) {
743806
node.get_primitive()->axis,
744807
false);
745808
} else if (crop_in_place_optimization::can_crop_be_optimized_simple_data_format(crop_layout, pred_layout)) {
809+
std::pair<const program_node*, layout> user_info;
746810
std::vector<layout> reshape_layouts;
747-
if (node.get_users().front()->is_type<reshape>() && node.get_users().front()->as<reshape>().is_runtime_propagatable_padding()) {
748-
reshape_layouts.push_back(node.get_users().front()->get_output_layout());
811+
if (node.get_users().front()->is_type<reshape>()) {
812+
auto& reshape_node = node.get_users().front()->as<reshape>();
813+
if (reshape_node.is_runtime_propagatable_padding()) {
814+
user_info.first = &reshape_node;
815+
user_info.second = reshape_node.get_output_layout();
816+
}
749817
}
750818
crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(crop_layout,
751819
pred_layout,
752-
reshape_layouts,
820+
user_info,
753821
crop_params->input_offsets[0],
754822
node.get_primitive()->axis,
755823
false);
756-
if (reshape_layouts.size() > 0) {
757-
node.get_users().front()->set_output_layout(reshape_layouts[0]);
824+
if (user_info.first) {
825+
node.get_users().front()->set_output_layout(user_info.second);
758826
}
759827
}
760828
node.set_output_layout(crop_layout);

src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ struct crop_in_place_optimization : pattern_match_optimization_typed<crop_in_pla
8282
bool is_runtime);
8383
static void update_in_place_crop_padding_simple_data_format(layout& crop_layout,
8484
layout& pred_layout,
85-
std::vector<layout>& user_layouts,
85+
std::pair<const program_node*, layout>& user_info,
8686
const tensor offsets,
8787
size_t crop_axis,
8888
bool is_runtime);

src/plugins/intel_gpu/src/graph/include/reshape_inst.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
3232

3333
bool is_runtime_propagatable_padding() const {
3434
auto prim = typed_desc();
35-
if (prim->mode == reshape::reshape_mode::squeeze || prim->mode == reshape::reshape_mode::unsqueeze)
36-
return true;
35+
if (prim->mode == reshape::reshape_mode::squeeze || prim->mode == reshape::reshape_mode::unsqueeze) {
36+
// For proper padding propagation we need to know output pattern at model loading stage
37+
// in case of squeeze/unsqueeze mode
38+
return prim->output_pattern.size() > 0;
39+
}
3740

3841
// TODO: This function is to limit condition to a specific case (crop + reshape) among cases for the base mode
3942
if (!input().is_type<crop>())

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1485,15 +1485,16 @@ void primitive_inst::do_runtime_in_place_crop() {
14851485
u->update_shape_done_by_other = true;
14861486

14871487
const auto& crop_users = u->get_user_insts();
1488-
std::vector<layout> reshape_layouts;
1488+
std::pair<const program_node*, layout> user_info;
14891489
if (crop_users.front()->get_node().is_type<reshape>()) {
14901490
OPENVINO_ASSERT(crop_users.size() == 1, "[GPU] Expected number of reshape users is 1, but it is ", crop_users.size());
14911491
auto reshape_inst = crop_users.front();
14921492
if (!reshape_inst->update_shape_done_by_other) {
14931493
GPU_DEBUG_TRACE_DETAIL << "[In place crop] update shape for " << reshape_inst->id() << std::endl;
14941494
reshape_inst->update_shape();
14951495
reshape_inst->update_shape_done_by_other = true;
1496-
reshape_layouts.push_back(reshape_inst->_impl_params->get_output_layout());
1496+
user_info.first = &reshape_inst->get_node();
1497+
user_info.second = reshape_inst->_impl_params->get_output_layout();
14971498
}
14981499
}
14991500

@@ -1510,11 +1511,10 @@ void primitive_inst::do_runtime_in_place_crop() {
15101511
if (crop_in_place_optimization::can_crop_be_optimized_along_feature(crop_layout, pred_layout)) {
15111512
crop_in_place_optimization::update_in_place_crop_padding_along_feature(u->get_node(), crop_layout, pred_layout, offsets, crop_axis, true);
15121513
} else if (crop_in_place_optimization::can_crop_be_optimized_simple_data_format(crop_layout, pred_layout)) {
1513-
crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(crop_layout, pred_layout, reshape_layouts,
1514-
offsets, crop_axis, true);
1515-
if (crop_users.front()->get_node().is_type<reshape>() && reshape_layouts.size() > 0) {
1514+
crop_in_place_optimization::update_in_place_crop_padding_simple_data_format(crop_layout, pred_layout, user_info, offsets, crop_axis, true);
1515+
if (user_info.first) {
15161516
auto reshape_inst = crop_users.front();
1517-
reshape_inst->_impl_params->output_layouts[0] = reshape_layouts[0];
1517+
reshape_inst->_impl_params->output_layouts[0] = user_info.second;
15181518
reshape_inst->set_shape_change();
15191519
}
15201520
} else {

src/plugins/intel_gpu/src/graph/reshape.cpp

+12-6
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ padding propagate_padding(const layout& in_layout, const ov::PartialShape& out_s
5151
update_pad_upper = pad_upper;
5252
update_pad_mask = pad_mask;
5353

54+
// Truncate to the actual rank (for shapes with a rank less than 4)
55+
update_pad_lower.resize(rank);
56+
update_pad_upper.resize(rank);
57+
update_pad_mask.resize(rank);
58+
5459
std::unordered_set<int64_t> tmp(axes.begin(), axes.end());
5560
std::vector<int64_t> unique_axes;
5661
const auto expanded_rank = rank + tmp.size();
@@ -61,13 +66,13 @@ padding propagate_padding(const layout& in_layout, const ov::PartialShape& out_s
6166
// Normalize then remove repeated axes after normalization.
6267
for (const auto& axis : axes) {
6368
if (static_cast<size_t>(axis) <= out_shape.size()) {
64-
pad_lower.insert(std::next(std::begin(pad_lower), axis), 0);
65-
pad_upper.insert(std::next(std::begin(pad_upper), axis), 0);
66-
pad_mask.insert(std::next(std::begin(pad_mask), axis), 0);
69+
update_pad_lower.insert(std::next(std::begin(update_pad_lower), axis), 0);
70+
update_pad_upper.insert(std::next(std::begin(update_pad_upper), axis), 0);
71+
update_pad_mask.insert(std::next(std::begin(update_pad_mask), axis), 0);
6772
} else {
68-
pad_lower.push_back(0);
69-
pad_upper.push_back(0);
70-
pad_mask.push_back(0);
73+
update_pad_lower.push_back(0);
74+
update_pad_upper.push_back(0);
75+
update_pad_mask.push_back(0);
7176
}
7277
}
7378
} else {
@@ -254,6 +259,7 @@ std::string reshape_inst::to_string(reshape_node const& node) {
254259
reshape_info.add("output pshape", desc->output_partial_shape);
255260
reshape_info.add("output pattern", desc->output_pattern);
256261
reshape_info.add("special zero", desc->special_zero);
262+
reshape_info.add("reshape mode", desc->mode);
257263

258264
node_info->add("reshape info", reshape_info);
259265
node_info->dump(primitive_description);

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl

+24-6
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,26 @@ KERNEL(rope_ref)(
7171
uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
7272
uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0;
7373
uint cos_sin_h = h < INPUT1_SIZE_Y ? h : 0;
74+
75+
#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS
7476
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0);
7577

78+
uint cos_idx = cos_sin_idx;
79+
uint sin_idx = cos_sin_idx;
80+
#else
81+
uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0);
82+
uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_p, cos_sin_h, 0);
83+
#endif
84+
7685
uint output_idx = OUTPUT_GET_INDEX(b, p, h, 0);
7786

7887
INPUT0_TYPE in1 = input[input_idx + r];
7988
INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r];
8089

81-
output[output_idx + r] = cos[cos_sin_idx + r] * in1 - sin[cos_sin_idx + r] * in2;
90+
output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2;
8291

83-
output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in2 +
84-
sin[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in1;
92+
output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 +
93+
sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1;
8594
}
8695
#endif
8796

@@ -128,16 +137,25 @@ KERNEL(rope_ref)(
128137
cos_sin_p = gather[gather_idx];
129138
#endif
130139
cos_sin_p = cos_sin_p < INPUT1_SIZE_Y ? cos_sin_p : 0;
140+
141+
#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS
131142
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0);
132143

144+
uint cos_idx = cos_sin_idx;
145+
uint sin_idx = cos_sin_idx;
146+
#else
147+
uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0);
148+
uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0);
149+
#endif
150+
133151
uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0);
134152

135153
INPUT0_TYPE in1 = input[input_idx + r];
136154
INPUT0_TYPE in2 = input[input_idx + HALF_ROTARY_NDIMS + r];
137155

138-
output[output_idx + r] = cos[cos_sin_idx + r] * in1 - sin[cos_sin_idx + r] * in2;
156+
output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2;
139157

140-
output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in2 +
141-
sin[cos_sin_idx + HALF_ROTARY_NDIMS + r] * in1;
158+
output[output_idx + HALF_ROTARY_NDIMS + r] = cos[cos_idx + HALF_ROTARY_NDIMS + r] * in2 +
159+
sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1;
142160
}
143161
#endif

0 commit comments

Comments (0)