Commit d737fa2

[GPU] Extend shape_of subgraphs markup logic to include PagedAttention input (#29445)
### Details:
- Extend shape_of subgraphs markup logic to include PagedAttention's `max_context_len` input.
- This patch fixes a qwen-7b-chat performance issue caused by the model's reliance on the `max_context_len` data input, which acts as a "ShapeOf" subgraph data source. Since this input is used only by PagedAttention itself (as a direct input) and by some simple shape-flow calculations, add it and its users to the ShapeOf subgraph as well. Otherwise, such subgraphs are executed on the GPU, introducing runtime synchronization and significantly dropping performance.

Qwen model:
![image](https://github.com/user-attachments/assets/8d42204f-aa1a-422d-af3c-907d205c0b33)

Qwen-2 model:
![image](https://github.com/user-attachments/assets/6d613e9c-5a58-41f7-8719-74192cf24c4e)

### Tickets:
- [CVS-164134](https://jira.devtools.intel.com/browse/CVS-164134)
Parent: 519420f
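
For context, here is a minimal sketch of the shape-flow pattern this patch targets, condensed from the unit test added below (the `neg_one` data primitive and the omitted topology setup are illustrative): `max_context_len` feeds PagedAttention directly and, in Qwen-like models, also a short chain of shape calculations that drives a dynamic reshape.

```cpp
// Condensed from the new unit test; "neg_one" is an illustrative data
// primitive holding -1, and the rest of the topology setup is omitted.
topology.add(input_layout("max_context_len", max_context_len_layout));
// Shape-flow user of max_context_len: a scalar adjustment (x + (-1))...
topology.add(eltwise("subtract_one_max_context_len", input_info("max_context_len"),
                     input_info("neg_one"), eltwise_mode::sum));
// ...folded into a broadcasted shape that drives a dynamic reshape.
topology.add(eltwise("updated_broadcast", input_info("broadcast"),
                     input_info("subtract_one_max_context_len"), eltwise_mode::sum));
topology.add(reshape("reshape", input_info("input"), input_info("updated_broadcast"),
                     false, ov::PartialShape::dynamic(4)));
// Before this patch, these eltwise nodes executed as GPU primitives, and
// reading their results for shape inference required a runtime sync; marking
// the whole chain as a ShapeOf subgraph keeps it on the host.
```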

3 files changed (+119, -3 lines)

src/plugins/intel_gpu/src/graph/graph_optimizer/mark_shape_of_subgraphs.cpp (+25, -2)
@@ -10,14 +10,37 @@
 #include "select_inst.h"
 #include "strided_slice_inst.h"
 #include "gather_inst.h"
+#include "input_layout_inst.h"
+#include "paged_attention_inst.h"
 #include "pass_manager.h"
 
 #include "intel_gpu/graph/program.hpp"
 
 using namespace cldnn;
 
-void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
+static bool is_shape_of_subgraph_root(program_node& node) {
     if (node.is_type<shape_of>()) {
+        return true;
+    }
+
+    // Allow input_layout to be the root of the shape_of subgraph if it's 'max_context_len'
+    // input of PagedAttention, which can be used as a shape calculation flow source in some
+    // models like Qwen and Qwen2
+    if (node.is_type<input_layout>()) {
+        const auto& users = node.get_users();
+        for (const auto& user : users) {
+            const auto max_context_len_input_id = 12;
+            if (user->is_type<paged_attention>() && user->get_dependency_index(node) == max_context_len_input_id) {
+                return true;
+            }
+        }
+    }
+
+    return false;
+}
+
+void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
+    if (is_shape_of_subgraph_root(node)) {
         mark_node(node);
         return;
     }
@@ -102,7 +125,7 @@ void mark_shape_of_subgraphs::mark_node(program_node& node) {
 
     // If current node has shape_of type add it to dependant shape_of nodes for
    // correct dependency propagation for users
-    if (node.is_type<shape_of>())
+    if (is_shape_of_subgraph_root(node))
         node.add_dependant_shape_of_node(&node);
 
     // Add parent shape_of nodes from other dependencies if there are any
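
One detail worth calling out in the hunk above: the hardcoded `max_context_len_input_id = 12` encodes PagedAttention's dependency order. For reference, a sketch of that ordering as an enum (illustrative only, not an existing plugin type; the order is taken from the `pa_inputs` vector in the unit test below):

```cpp
// Illustrative sketch: PagedAttention dependency indices, in the order the
// pa_inputs vector of the unit test below passes them to the primitive.
enum paged_attention_input_idx : size_t {
    QUERY = 0, KEY, VALUE, KEY_CACHE, VALUE_CACHE,
    PAST_LENS, SUBSEQUENCE_BEGINS, BLOCK_INDICES, BLOCK_INDICES_BEGINS,
    SCALE, SLIDING_WINDOW, ALIBI,
    MAX_CONTEXT_LEN  // == 12, the index checked in is_shape_of_subgraph_root()
};
```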

src/plugins/intel_gpu/src/graph/program_node.cpp (+1, -1)
@@ -658,7 +658,7 @@ void program_node::select_preferred_formats(impl_types impl_type) {
 }
 
 void program_node::add_dependant_shape_of_node(const program_node* node) {
-    OPENVINO_ASSERT(node->is_type<shape_of>(), "[GPU] Expected node type is shape_of");
+    OPENVINO_ASSERT(node->is_type<shape_of>() || node->is_type<input_layout>(), "[GPU] Expected node type is shape_of");
     dependant_shape_of_nodes.insert(node);
 }
 

src/plugins/intel_gpu/tests/unit/passes/mark_shape_of_subgraphs_test.cpp (+93)
@@ -18,6 +18,7 @@
 #include "select_inst.h"
 #include "strided_slice_inst.h"
 #include "broadcast_inst.h"
+#include "paged_attention_inst.h"
 #include "pass_manager.h"
 #include "to_string_utils.h"
 
@@ -31,6 +32,10 @@ static bool check_subgraph(const program_node& node, const program_node& last_no
     if (custom_dependant_nodes_count.find(node.id()) != custom_dependant_nodes_count.end())
         expected_dependant_nodes = custom_dependant_nodes_count[node.id()];
 
+    // Skip some custom nodes if they are not intended to be included into shape_of subgraph
+    if (expected_dependant_nodes == 0)
+        return true;
+
     if (!node.is_in_shape_of_subgraph() || node.get_dependant_shape_of_nodes().size() != expected_dependant_nodes)
         return false;
 
@@ -423,3 +428,91 @@ TEST(mark_shape_of_subgraphs, broadcast_w_direct_shapeof_and_data) {
 
     ASSERT_TRUE(check_subgraph(prog->get_node("shape_of"), prog->get_node("broadcast")));
 }
+
+TEST(mark_shape_of_subgraphs, paged_attention_max_context_len_input) {
+    auto& engine = get_test_engine();
+    auto input_layout_dynamic = layout{ov::PartialShape{ov::Dimension::dynamic(), 4, ov::Dimension::dynamic(), ov::Dimension::dynamic()},
+                                       data_types::f32, format::bfyx};
+    auto target_shape = engine.allocate_memory({ ov::PartialShape{4}, data_types::i32, format::bfyx });
+    set_values(target_shape, {4, 4, 1, 1});
+
+    auto subtract_one = engine.allocate_memory({ ov::PartialShape{1}, data_types::i32, format::bfyx });
+    set_values(subtract_one, {-1});
+
+    auto query_layout = layout{ov::PartialShape{ov::Dimension::dynamic(), 128},
+                               data_types::f32,
+                               format::bfyx};
+    auto key_layout = query_layout;
+    auto value_layout = query_layout;
+    auto key_cache_layout = layout{ov::PartialShape{ov::Dimension::dynamic(), 2, 64, 16},
+                                   data_types::f32,
+                                   format::bfyx};
+    auto dynamic_i32_layout = layout{ov::PartialShape::dynamic(1), data_types::i32, format::bfyx};
+    auto value_cache_layout = key_cache_layout;
+    auto past_lens_layout = dynamic_i32_layout;
+    auto subsequence_begins_layout = dynamic_i32_layout;
+    auto block_indices_layout = dynamic_i32_layout;
+    auto block_indices_begins_layout = dynamic_i32_layout;
+    auto scale_layout = layout{ov::PartialShape{1}, data_types::f32, format::bfyx};
+    auto sliding_window_layout = layout{ov::PartialShape{}, data_types::i32, format::bfyx};
+    auto alibi_layout = layout{ov::PartialShape{}, data_types::f32, format::bfyx};
+    auto max_context_len_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
+
+    std::vector<input_info> pa_inputs = {input_info("query"),
+                                         input_info("key"),
+                                         input_info("value"),
+                                         input_info("key_cache"),
+                                         input_info("value_cache"),
+                                         input_info("past_lens"),
+                                         input_info("subsequence_begins"),
+                                         input_info("block_indices"),
+                                         input_info("block_indices_begins"),
+                                         input_info("scale"),
+                                         input_info("sliding_window"),
+                                         input_info("alibi"),
+                                         input_info("max_context_len")};
+
+    auto pa_prim = paged_attention("paged_attention", pa_inputs);
+    pa_prim.head_size = 64;
+    pa_prim.kv_heads_num = 2;
+    pa_prim.heads_num = 2;
+    pa_prim.scale_val = 1.f;
+    pa_prim.has_alibi = false;
+    pa_prim.num_outputs = 1;
+    pa_prim.has_rotated_blocks = false;
+
+    topology topology;
+    topology.add(input_layout("query", query_layout));
+    topology.add(input_layout("key", key_layout));
+    topology.add(input_layout("value", value_layout));
+    topology.add(input_layout("key_cache", key_cache_layout));
+    topology.add(input_layout("value_cache", value_cache_layout));
+    topology.add(input_layout("past_lens", past_lens_layout));
+    topology.add(input_layout("subsequence_begins", subsequence_begins_layout));
+    topology.add(input_layout("block_indices", block_indices_layout));
+    topology.add(input_layout("block_indices_begins", block_indices_begins_layout));
+    topology.add(input_layout("scale", scale_layout));
+    topology.add(input_layout("sliding_window", sliding_window_layout));
+    topology.add(input_layout("alibi", alibi_layout));
+    topology.add(input_layout("max_context_len", max_context_len_layout));
+    topology.add(input_layout("input", input_layout_dynamic));
+    topology.add(data("target_shape", target_shape));
+    topology.add(data("subtract_one", subtract_one));
+    topology.add(shape_of("shape_of", input_info("input"), data_types::i32));
+    topology.add(broadcast("broadcast", input_info("shape_of"), input_info("target_shape"), {}, ov::op::BroadcastType::BIDIRECTIONAL));
+    topology.add(eltwise("subtract_one_max_context_len", input_info("max_context_len"), input_info("subtract_one"), eltwise_mode::sum));
+    topology.add(eltwise("updated_broadcast", input_info("broadcast"), input_info("subtract_one_max_context_len"), eltwise_mode::sum));
+    topology.add(reshape("reshape", input_info("input"), input_info("updated_broadcast"), false, ov::PartialShape::dynamic(4)));
+    topology.add(pa_prim);
+
+    ExecutionConfig config = get_test_default_config(engine);
+    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+    config.set_property(ov::intel_gpu::optimize_data(true));
+    network network(engine, topology, config);
+
+    auto prog = network.get_program();
+    ASSERT_NE(prog, nullptr);
+
+    ASSERT_TRUE(check_subgraph(prog->get_node("shape_of"), prog->get_node("updated_broadcast"), {{"updated_broadcast", 2}}));
+    ASSERT_TRUE(check_subgraph(prog->get_node("max_context_len"), prog->get_node("updated_broadcast"), {{"updated_broadcast", 2}, {"paged_attention", 0}}));
+}
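
As a follow-up check, a minimal sketch of how the new subgraph root could be verified node-by-node (assuming `prog` from the test above and the same `is_in_shape_of_subgraph()` accessor the `check_subgraph` helper relies on):

```cpp
// Every node on the max_context_len -> updated_broadcast shape-flow path
// should report ShapeOf-subgraph membership after the pass, while pure
// data-flow nodes such as "input" should not.
for (const auto* id : {"max_context_len", "subtract_one_max_context_len", "updated_broadcast"}) {
    ASSERT_TRUE(prog->get_node(id).is_in_shape_of_subgraph());
}
ASSERT_FALSE(prog->get_node("input").is_in_shape_of_subgraph());
```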
