18
18
#include " select_inst.h"
19
19
#include " strided_slice_inst.h"
20
20
#include " broadcast_inst.h"
21
+ #include " paged_attention_inst.h"
21
22
#include " pass_manager.h"
22
23
#include " to_string_utils.h"
23
24
@@ -31,6 +32,10 @@ static bool check_subgraph(const program_node& node, const program_node& last_no
31
32
if (custom_dependant_nodes_count.find (node.id ()) != custom_dependant_nodes_count.end ())
32
33
expected_dependant_nodes = custom_dependant_nodes_count[node.id ()];
33
34
35
+ // Skip some custom nodes if they are not intended to be included into shape_of subgraph
36
+ if (expected_dependant_nodes == 0 )
37
+ return true ;
38
+
34
39
if (!node.is_in_shape_of_subgraph () || node.get_dependant_shape_of_nodes ().size () != expected_dependant_nodes)
35
40
return false ;
36
41
@@ -423,3 +428,91 @@ TEST(mark_shape_of_subgraphs, broadcast_w_direct_shapeof_and_data) {
423
428
424
429
ASSERT_TRUE (check_subgraph (prog->get_node (" shape_of" ), prog->get_node (" broadcast" )));
425
430
}
431
+
432
+ TEST (mark_shape_of_subgraphs, paged_attention_max_context_len_input) {
433
+ auto & engine = get_test_engine ();
434
+ auto input_layout_dynamic = layout{ov::PartialShape{ov::Dimension::dynamic (), 4 , ov::Dimension::dynamic (), ov::Dimension::dynamic ()},
435
+ data_types::f32, format::bfyx};
436
+ auto target_shape = engine.allocate_memory ({ ov::PartialShape{4 }, data_types::i32, format::bfyx });
437
+ set_values (target_shape, {4 , 4 , 1 , 1 });
438
+
439
+ auto subtract_one = engine.allocate_memory ({ ov::PartialShape{1 }, data_types::i32, format::bfyx });
440
+ set_values (target_shape, {-1 });
441
+
442
+ auto query_layout = layout{ov::PartialShape{ov::Dimension::dynamic (), 128 },
443
+ data_types::f32,
444
+ format::bfyx};
445
+ auto key_layout = query_layout;
446
+ auto value_layout = query_layout;
447
+ auto key_cache_layout = layout{ov::PartialShape{ov::Dimension::dynamic (), 2 , 64 , 16 },
448
+ data_types::f32,
449
+ format::bfyx};
450
+ auto dynamic_i32_layout = layout{ov::PartialShape::dynamic (1 ), data_types::i32, format::bfyx};
451
+ auto value_cache_layout = key_cache_layout;
452
+ auto past_lens_layout = dynamic_i32_layout;
453
+ auto subsequence_begins_layout = dynamic_i32_layout;
454
+ auto block_indices_layout = dynamic_i32_layout;
455
+ auto block_indices_begins_layout = dynamic_i32_layout;
456
+ auto scale_layout = layout{ov::PartialShape{1 }, data_types::f32, format::bfyx};
457
+ auto sliding_window_layout = layout{ov::PartialShape{}, data_types::i32, format::bfyx};
458
+ auto alibi_layout = layout{ov::PartialShape{}, data_types::f32, format::bfyx};
459
+ auto max_context_len_layout = layout{ov::PartialShape{1 }, data_types::i32, format::bfyx};;
460
+
461
+ std::vector<input_info> pa_inputs = {input_info (" query" ),
462
+ input_info (" key" ),
463
+ input_info (" value" ),
464
+ input_info (" key_cache" ),
465
+ input_info (" value_cache" ),
466
+ input_info (" past_lens" ),
467
+ input_info (" subsequence_begins" ),
468
+ input_info (" block_indices" ),
469
+ input_info (" block_indices_begins" ),
470
+ input_info (" scale" ),
471
+ input_info (" sliding_window" ),
472
+ input_info (" alibi" ),
473
+ input_info (" max_context_len" )};
474
+
475
+ auto pa_prim = paged_attention (" paged_attention" , pa_inputs);
476
+ pa_prim.head_size = 64 ;
477
+ pa_prim.kv_heads_num = 2 ;
478
+ pa_prim.heads_num = 2 ;
479
+ pa_prim.scale_val = 1 .f ;
480
+ pa_prim.has_alibi = false ;
481
+ pa_prim.num_outputs = 1 ;
482
+ pa_prim.has_rotated_blocks = false ;
483
+
484
+ topology topology;
485
+ topology.add (input_layout (" query" , query_layout));
486
+ topology.add (input_layout (" key" , key_layout));
487
+ topology.add (input_layout (" value" , value_layout));
488
+ topology.add (input_layout (" key_cache" , key_cache_layout));
489
+ topology.add (input_layout (" value_cache" , value_cache_layout));
490
+ topology.add (input_layout (" past_lens" , past_lens_layout));
491
+ topology.add (input_layout (" subsequence_begins" , subsequence_begins_layout));
492
+ topology.add (input_layout (" block_indices" , block_indices_layout));
493
+ topology.add (input_layout (" block_indices_begins" , block_indices_begins_layout));
494
+ topology.add (input_layout (" scale" , scale_layout));
495
+ topology.add (input_layout (" sliding_window" , sliding_window_layout));
496
+ topology.add (input_layout (" alibi" , alibi_layout));
497
+ topology.add (input_layout (" max_context_len" , max_context_len_layout));
498
+ topology.add (input_layout (" input" , input_layout_dynamic));
499
+ topology.add (data (" target_shape" , target_shape));
500
+ topology.add (data (" subtract_one" , subtract_one));
501
+ topology.add (shape_of (" shape_of" , input_info (" input" ), data_types::i32));
502
+ topology.add (broadcast (" broadcast" , input_info (" shape_of" ), input_info (" target_shape" ), {}, ov::op::BroadcastType::BIDIRECTIONAL));
503
+ topology.add (eltwise (" subtract_one_max_context_len" , input_info (" max_context_len" ), input_info (" subtract_one" ), eltwise_mode::sum));
504
+ topology.add (eltwise (" updated_broadcast" , input_info (" broadcast" ), input_info (" subtract_one_max_context_len" ), eltwise_mode::sum));
505
+ topology.add (reshape (" reshape" , input_info (" input" ), input_info (" updated_broadcast" ), false , ov::PartialShape::dynamic (4 )));
506
+ topology.add (pa_prim);
507
+
508
+ ExecutionConfig config = get_test_default_config (engine);
509
+ config.set_property (ov::intel_gpu::allow_new_shape_infer (true ));
510
+ config.set_property (ov::intel_gpu::optimize_data (true ));
511
+ network network (engine, topology, config);
512
+
513
+ auto prog = network.get_program ();
514
+ ASSERT_NE (prog, nullptr );
515
+
516
+ ASSERT_TRUE (check_subgraph (prog->get_node (" shape_of" ), prog->get_node (" updated_broadcast" ), {{" updated_broadcast" , 2 }}));
517
+ ASSERT_TRUE (check_subgraph (prog->get_node (" max_context_len" ), prog->get_node (" updated_broadcast" ), {{" updated_broadcast" , 2 }, {" paged_attention" , 0 }}));
518
+ }
0 commit comments