Skip to content

Commit c801f4e

Browse files
[GPU] Relax UnsqueezeBroadcastReshapeSDPAFusion (openvinotoolkit#27515)
### Details: - By relaxing UnsqueezeBroadcastReshapeSDPAFusion, GQA pattern is enabled and Broadcasting nodes overheads in paths of key and value are removed, thus improves performance of GLM4 model significantly. - Fix for GLM4V, which has initial state shape (-1, 0, 0, 0), and shape infer failed. ### Tickets: - *CVS-157263* --------- Co-authored-by: Chen Peter <peter.chen@intel.com>
1 parent 0f149e3 commit c801f4e

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,21 @@ std::vector<ov::PartialShape> shape_infer(const KVCache* op, const std::vector<o
106106

107107
const auto& gather_axis = op->get_gather_axis();
108108
const auto& concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size());
109+
// We update output shape with input1 shape by default, as input1 is always new, and in some situations, input0 shape
110+
// has zeros in some dimensions. For example to concat input0 [-1, 0, 0, 0] + input1 [-1, 4, -1, 128] along axis 2,
111+
// we could (and should) infer dim value of axis 1 and 3 in this case.
109112
if (op->get_output_size() >= 2) {
110-
out_shapes[0] = input_shapes[0];
113+
out_shapes[0] = input_shapes[1];
111114
out_shapes[0][gather_axis] = input_shapes[2][0];
112-
out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
115+
out_shapes[0][concat_axis] += input_shapes[0][concat_axis];
113116

114117
std::vector<ov::Dimension> dims(out_shapes[0].size(), 1);
115118
dims[gather_axis] = out_shapes[0][gather_axis];
116119
dims[concat_axis] = out_shapes[0][concat_axis];
117120
out_shapes[1] = dims;
118121
} else {
119-
out_shapes[0] = input_shapes[0];
120-
out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
122+
out_shapes[0] = input_shapes[1];
123+
out_shapes[0][concat_axis] += input_shapes[0][concat_axis];
121124
}
122125

123126
return out_shapes;

src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,12 @@ std::vector<ov::PartialShape> shape_infer(const SDPA* op,
144144
if (is_broadcastable) {
145145
size_t max_rank = shape_q_t.size();
146146
for (size_t i = 0; i < max_rank; ++i) {
147-
if (shape_q_t[i].is_static() && shape_k_t[i].is_static() && shape_v_t[i].is_static()) {
147+
if (shape_q_t[i].is_static() && shape_k_t[i].is_static()) {
148148
auto broadcasted_dim = shape_q_t[i].get_length();
149149
shape_k_t[i] = broadcasted_dim;
150+
}
151+
if (shape_q_t[i].is_static() && shape_v_t[i].is_static()) {
152+
auto broadcasted_dim = shape_q_t[i].get_length();
150153
shape_v_t[i] = broadcasted_dim;
151154
}
152155
}

src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp

+1-5
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ using ov::pass::pattern::op::Or;
2323
UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
2424
using namespace ov::pass::pattern;
2525

26-
auto not_reshape = [](const ov::Output<ov::Node>& output) -> bool {
27-
return std::dynamic_pointer_cast<ov::op::v1::Reshape>(output.get_node_shared_ptr()) == nullptr;
28-
};
29-
3026
auto unsqueeze_predicate = [](const ov::Output<ov::Node>& output) -> bool {
3127
return rank_equals(5)(output) && consumers_count(1);
3228
};
@@ -42,7 +38,7 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
4238
return rank_equals(4)(output) && consumers_count(1);
4339
};
4440

45-
auto input_a_m = any_input(not_reshape);
41+
auto input_a_m = any_input();
4642
auto input_attn_mask = any_input();
4743
auto input_scale = any_input();
4844
auto input_b_m = wrap_type<ov::intel_gpu::op::KVCache>({any_input(), any_input()});

0 commit comments

Comments (0)