Commit ea05578

[GPU] PagedAttention improvements and fixes (#29248)
### Details:
- Extended IncreasePositionIdsPrecision to support MatMul followed by Reshape, improving accuracy for large context sizes
- Added support for the sliding window size PA parameter
- Fixed PA caching: added the missing scale and fixed the PA primitive descriptor comparator
1 parent da0a6d3 commit ea05578
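
For readers skimming the diff, the sliding-window changes boil down to one extra masking condition in the SDPA kernels. The standalone C++ sketch below is not part of the commit; names and the exact boundary handling are illustrative (the two kernels differ slightly in where they place the window edge), but it captures the rule the kernels enable via the new SLIDING_WINDOW_SIZE JIT constant.

```cpp
#include <cstddef>

// Illustrative helper (not from the PR): decides whether key position `k` is masked
// for query position `q` under a causal mask with an optional sliding window.
// sliding_window == 0 means "no window", matching the kernels' SLIDING_WINDOW_SIZE == 0 default.
bool is_masked(std::size_t q, std::size_t k, std::size_t sliding_window) {
    if (k > q)
        return true;  // causal part: no attention to future tokens
    if (sliding_window != 0 && q >= sliding_window && k < q - sliding_window)
        return true;  // windowed part: key fell out of the last `sliding_window` tokens
    return false;     // key is visible to the query
}
```

In the kernels this appears as an extra `#if SLIDING_WINDOW_SIZE != 0` branch wrapped around the existing causal checks in pa_sdpa_opt.cl and sdpa_opt.cl.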

File tree

10 files changed: +108 -2 lines changed

src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp

+32 -1

@@ -30,7 +30,18 @@ struct paged_attention : public primitive_base<paged_attention> {
     }
 
     bool operator==(const primitive& rhs) const override {
-        return compare_common_params(rhs);
+        if (!compare_common_params(rhs))
+            return false;
+
+        auto rhs_casted = downcast<const paged_attention>(rhs);
+
+        return head_size == rhs_casted.head_size &&
+               heads_num == rhs_casted.heads_num &&
+               kv_heads_num == rhs_casted.kv_heads_num &&
+               sliding_window == rhs_casted.sliding_window &&
+               has_alibi == rhs_casted.has_alibi &&
+               has_rotated_blocks == rhs_casted.has_rotated_blocks &&
+               scale_val.value_or(1.0f) == rhs_casted.scale_val.value_or(1.0f);
     }
 
     void save(BinaryOutputBuffer& ob) const override {
@@ -40,6 +51,14 @@ struct paged_attention : public primitive_base<paged_attention> {
         ob << kv_heads_num;
         ob << has_alibi;
         ob << has_rotated_blocks;
+        ob << sliding_window;
+
+        if (scale_val.has_value()) {
+            ob << true;
+            ob << scale_val.value();
+        } else {
+            ob << false;
+        }
     }
 
     void load(BinaryInputBuffer& ib) override {
@@ -49,12 +68,24 @@ struct paged_attention : public primitive_base<paged_attention> {
         ib >> kv_heads_num;
         ib >> has_alibi;
         ib >> has_rotated_blocks;
+        ib >> sliding_window;
+
+        bool has_scale;
+        ib >> has_scale;
+        if (has_scale) {
+            float scale = 1.0f;
+            ib >> scale;
+            scale_val = scale;
+        } else {
+            scale_val = std::optional<float>();
+        }
     }
 
     std::optional<float> scale_val{};
     size_t head_size = 0;
     size_t heads_num = 0;
     size_t kv_heads_num = 0;
+    size_t sliding_window = 0;
     bool has_alibi = false;
     bool has_rotated_blocks = false;
 };
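
The caching-related part of this header change is easiest to see with a simplified stand-in. The sketch below is a hypothetical, trimmed-down descriptor, not the real `paged_attention` primitive; it mirrors the extended `operator==` above. The assumption (hedged, but consistent with the "fixed PA primitive descriptor comparator" note in the description) is that cached primitives are reused when descriptors compare equal, so two nodes differing only in `sliding_window` or `scale_val` must not compare equal, and an absent scale is treated as the default 1.0f on both sides via `value_or`.

```cpp
#include <cstddef>
#include <optional>

// Hypothetical stand-in for the descriptor fields compared above (not the real class).
struct pa_desc {
    std::size_t head_size = 0, heads_num = 0, kv_heads_num = 0, sliding_window = 0;
    bool has_alibi = false, has_rotated_blocks = false;
    std::optional<float> scale_val{};

    bool operator==(const pa_desc& rhs) const {
        return head_size == rhs.head_size && heads_num == rhs.heads_num &&
               kv_heads_num == rhs.kv_heads_num && sliding_window == rhs.sliding_window &&
               has_alibi == rhs.has_alibi && has_rotated_blocks == rhs.has_rotated_blocks &&
               scale_val.value_or(1.0f) == rhs.scale_val.value_or(1.0f);  // absent scale == default 1.0f
    }
};

// pa_desc a, b; b.sliding_window = 4096;
// (a == b) is now false, so the two configurations are no longer conflated by the cache.
```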

src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

+1

@@ -665,6 +665,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         config.is_causal = true;
         config.is_paged_attention = true;
         config.paged_attention_block_size = static_cast<int64_t>(paged_attention::block_size);
+        config.paged_attention_sliding_window = desc->sliding_window;
 
         if (desc->scale_val.has_value()) {
             config.has_const_scale_val = true;

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+4

@@ -197,7 +197,11 @@ KERNEL(pa_sdpa_opt)(
         qk_acc += alibi_slopes[head_num_idx] * alibi_val;
 #endif
 
+#if SLIDING_WINDOW_SIZE != 0
+        if (token_idx >= seq_len || (seq_len > SLIDING_WINDOW_SIZE && token_idx < (seq_len - SLIDING_WINDOW_SIZE)))
+#else
         if (token_idx >= seq_len)
+#endif
             qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;
 
         qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl

+9

@@ -784,7 +784,12 @@ inline MASK_VECTOR_TYPE FUNC(load_attn_mask)(OPTIONAL_SHAPE_INFO_ARG
         }
     } else {
         for (uint i = 0; i < SUBGROUP_SIZE; i++) {
+#if defined(IS_PAGED_ATTENTION) && SLIDING_WINDOW_SIZE != 0
+            if ((source_seq_idx + i > target_seq_idx) ||
+                (target_seq_idx >= SLIDING_WINDOW_SIZE && source_seq_idx + i < target_seq_idx - SLIDING_WINDOW_SIZE))
+#else
             if (source_seq_idx + i > target_seq_idx)
+#endif
                 mask_vec[i] = NAN;
         }
     }
@@ -1167,7 +1172,11 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
     unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
 #if IS_CAUSAL
         // casual mask: valid only if m >= n
+#if defined(IS_PAGED_ATTENTION) && SLIDING_WINDOW_SIZE != 0
+        if ((seq_len + i <= target_seq_idx + sglid) && (target_seq_idx + sglid < SLIDING_WINDOW_SIZE || seq_len + i >= target_seq_idx + sglid - SLIDING_WINDOW_SIZE)) {
+#else
         if (seq_len + i <= target_seq_idx + sglid) {
+#endif
 #endif // IS_CAUSAL
 #if !APPLY_SCALES_TO_QUERY
 #if HAS_SCALE_INPUT

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp

+1

@@ -206,6 +206,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", seq_len_partition_size));
     jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size));
     jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
+    jit.AddConstant(MakeJitConstant("SLIDING_WINDOW_SIZE", config.paged_attention_sliding_window));
 
     if (config.broadcast_axis != -1) {
         jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h

+1

@@ -95,6 +95,7 @@ struct sdpa_configuration {
 
     // Paged Attention configuration
     bool is_paged_attention = false;
+    size_t paged_attention_sliding_window = 0;
    int64_t paged_attention_aligned_seq_len = -1;
     int64_t paged_attention_block_size = 0;
     int64_t paged_attention_max_len = 0;

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp

+1

@@ -186,6 +186,7 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
     jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, config.head_size, kernel_idx)));
 
     if (params.conf.is_paged_attention) {
+        jit.AddConstant(MakeJitConstant("SLIDING_WINDOW_SIZE", params.conf.paged_attention_sliding_window));
         if (params.conf.has_alibi_input) {
             jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
         }

src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp

+7

@@ -52,6 +52,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
     prim.heads_num = heads_num;
 
     const size_t scale_idx = 9;
+    const size_t sliding_window_idx = 10;
     const size_t alibi_idx = 11;
 
     std::shared_ptr<ov::op::v0::Constant> scale_const = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(scale_idx));
@@ -62,6 +63,12 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
         prim.scale_val = std::optional<float>();
     }
 
+    std::shared_ptr<ov::op::v0::Constant> sliding_windows_const = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(sliding_window_idx));
+    if (sliding_windows_const) {
+        OPENVINO_ASSERT(ov::shape_size(sliding_windows_const->get_output_shape(0)) == 1);
+        prim.sliding_window = sliding_windows_const->cast_vector<size_t>()[0];
+    }
+
     std::shared_ptr<ov::op::v0::Constant> alibi_const = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(alibi_idx));
     OPENVINO_ASSERT(alibi_const != nullptr);
     prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0;

src/plugins/intel_gpu/src/plugin/transformations/increase_position_ids_precision.cpp

+2 -1

@@ -31,7 +31,8 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {
 
     auto gemm_or_matmul = wrap_type<ov::intel_gpu::op::Gemm, ov::op::v0::MatMul>();
     auto transpose_m = wrap_type<ov::op::v1::Transpose>({gemm_or_matmul, any_input()});
-    auto concat_input = std::make_shared<Or>(OutputVector{gemm_or_matmul, transpose_m});
+    auto reshape_m = wrap_type<ov::op::v1::Reshape>({gemm_or_matmul, any_input()});
+    auto concat_input = std::make_shared<Or>(OutputVector{gemm_or_matmul, transpose_m, reshape_m});
     auto concat = wrap_type<ov::op::v0::Concat>({concat_input, concat_input});
     auto sin = wrap_type<ov::op::v0::Sin>({concat});
     auto cos = wrap_type<ov::op::v0::Cos>({concat});
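
As a side note on why this pass matters (and why the extra Reshape pattern case is worth having): the rotary angles come from a MatMul of integer position ids with an inverse-frequency constant, and in f16 large position ids are no longer exactly representable, which is consistent with the "improving accuracy for large context sizes" note above. A minimal sketch, assuming a compiler that provides the `_Float16` extension (GCC/Clang); the numbers are illustrative only:

```cpp
#include <cstdio>

int main() {
    // f16 has an 11-bit significand, so integers above 2048 get rounded;
    // a position id of 4097 is not representable in f16 and rounds to 4096.
    const int position_id = 4097;
    const _Float16 pos_f16 = static_cast<_Float16>(position_id);
    const float pos_f32 = static_cast<float>(position_id);
    std::printf("f16: %.1f  f32: %.1f\n", static_cast<float>(pos_f16), pos_f32);  // f16: 4096.0  f32: 4097.0
    // IncreasePositionIdsPrecision keeps the angle-producing MatMul (now also when followed
    // by a Reshape) in f32 and converts back to f16 only after Sin/Cos, as the reference
    // model in the new test below shows.
    return 0;
}
```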

src/plugins/intel_gpu/tests/unit/transformations/increase_position_ids_precision_test.cpp

+50

@@ -176,3 +176,53 @@ TEST_F(TransformationTestsF, IncreasePositionIdsMatmulWithoutUnsqueeze) {
     }
     comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
 }
+
+TEST_F(TransformationTestsF, IncreasePositionIdsReshapeAfterMatmul) {
+    {
+        auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
+        auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
+        auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
+        auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16);
+        auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});
+        auto reshape_dims = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{3});
+
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_fp, rotary_embd_const);
+        auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, reshape_dims, true);
+        auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2);
+
+        auto cos = std::make_shared<ov::op::v0::Cos>(concat);
+        auto sin = std::make_shared<ov::op::v0::Sin>(concat);
+
+        auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
+        auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos, sin}, ov::op::internal::RoPE::Config());
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, reshape_dims });
+        manager.register_pass<IncreasePositionIdsPrecision>();
+    }
+    {
+        auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
+        auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
+        auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
+        auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});
+        auto reshape_dims = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{3});
+
+        auto input_convert_f32 = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f32);
+        auto rotary_embd_const_convert_f32 = std::make_shared<ov::op::v0::Convert>(rotary_embd_const, ov::element::f32);
+
+        auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_f32, rotary_embd_const_convert_f32);
+        auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, reshape_dims, true);
+        auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2);
+
+        auto cos = std::make_shared<ov::op::v0::Cos>(concat);
+        auto sin = std::make_shared<ov::op::v0::Sin>(concat);
+
+        auto cos_convert = std::make_shared<ov::op::v0::Convert>(cos, ov::element::f16);
+        auto sin_convert = std::make_shared<ov::op::v0::Convert>(sin, ov::element::f16);
+
+        auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
+        auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_convert, sin_convert}, ov::op::internal::RoPE::Config());
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, reshape_dims });
+    }
+    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
+}
