@@ -601,3 +601,170 @@ TEST(loop_gpu, support_dynamic_tensoriterator_outer_axis) {
601
601
602
602
test_loop_gpu_wo_trip_count ({ 2 , 1 , 1 , 2 }, { 2 , 5 , 1 , 2 }, input_data_5_4, output_data_5_4, 1 , 4 );
603
603
}
604
+
605
+ static void test_loop_gpu_wo_trip_count_w_multiple_shapes (ov::PartialShape body_input_layout,
606
+ std::vector<ov::PartialShape> whole_layouts,
607
+ std::vector<std::vector<float >> input_data_list,
608
+ std::vector<float > expected_output_data,
609
+ size_t axis,
610
+ size_t exit_value,
611
+ bool is_caching_test = false ) {
612
+ auto & engine = get_test_engine ();
613
+
614
+ auto b_input_layout = cldnn::layout{ body_input_layout, data_types::f32, format::bfyx };
615
+
616
+ ov::PartialShape sliced_input_shape = body_input_layout;
617
+ sliced_input_shape[axis] = 1 ;
618
+ auto sliced_input_layout = cldnn::layout{ sliced_input_shape, data_types::f32, format::bfyx };
619
+
620
+ auto const_layout = cldnn::layout{ {}, data_types::i64, format::bfyx };
621
+
622
+ auto e_initial_condition_mem = engine.allocate_memory (const_layout);
623
+ auto e_num_iteration_mem = engine.allocate_memory (const_layout);
624
+ auto b_exit_value_mem = engine.allocate_memory (const_layout);
625
+ auto b_index_inc_mem = engine.allocate_memory (const_layout);
626
+
627
+ // initialize input buffers
628
+ set_values (e_initial_condition_mem, {1 });
629
+ set_values (b_exit_value_mem, {exit_value});
630
+ set_values (b_index_inc_mem, {1 });
631
+ set_values (e_num_iteration_mem, {0 });
632
+
633
+ primitive_id body_current_iteration_id = " b_index" ;
634
+ primitive_id body_execution_condition_id = " b_cond_exit_value" ;
635
+
636
+ cldnn::topology body (
637
+ input_layout (body_current_iteration_id, const_layout),
638
+ input_layout (" b_add_data" , sliced_input_layout),
639
+ input_layout (" b_mul_data" , sliced_input_layout),
640
+ data (" b_exit_value" , b_exit_value_mem),
641
+ data (" b_index_inc" , b_index_inc_mem),
642
+ eltwise (" b_index_update" , input_info (body_current_iteration_id), input_info (" b_index_inc" ), eltwise_mode::sum),
643
+ reorder (" b_index_cast" , input_info (" b_index_update" ),
644
+ cldnn::format::any, data_types::f32, {}, cldnn::reorder_mean_mode::subtract, cldnn::padding (), true ),
645
+ eltwise (body_execution_condition_id, input_info (" b_index" ), input_info (" b_exit_value" ), eltwise_mode::lt),
646
+ eltwise (" b_add" , input_info (" b_add_data" ), input_info (" b_index_cast" ), eltwise_mode::sum),
647
+ eltwise (" b_mul" , input_info (" b_mul_data" ), input_info (" b_index_cast" ), eltwise_mode::prod));
648
+
649
+ primitive_id trip_count_id = " " ;
650
+ primitive_id actual_iteration_count_id = " actual_iteration_count" ;
651
+ primitive_id initial_condition_id = " initial_condition" ;
652
+ int64_t num_iterations = -1 ;
653
+
654
+ std::vector<loop::io_primitive_map> input_primitive_maps {
655
+ loop::io_primitive_map (" input" , " b_add_data" , axis),
656
+ loop::io_primitive_map (" input" , " b_mul_data" , axis),
657
+ loop::io_primitive_map (actual_iteration_count_id, body_current_iteration_id) };
658
+ std::vector<loop::io_primitive_map> output_primitive_maps {
659
+ loop::io_primitive_map (cldnn::input_info (" loop" , 0 ), cldnn::input_info (" b_add" , 0 ), axis),
660
+ loop::io_primitive_map (cldnn::input_info (" loop" , 1 ), cldnn::input_info (" b_mul" , 0 ), axis) };
661
+ std::vector<loop::backedge_mapping> back_edges {
662
+ loop::backedge_mapping (" b_index_update" , body_current_iteration_id) };
663
+
664
+ auto body_program = build_program (engine, body, body_execution_condition_id, output_primitive_maps, back_edges, true );
665
+
666
+ cldnn::topology topology (
667
+ input_layout (" input" , b_input_layout),
668
+ input_layout (initial_condition_id, e_initial_condition_mem->get_layout ()),
669
+ mutable_data (actual_iteration_count_id, e_num_iteration_mem),
670
+ loop (" loop" , { input_info (actual_iteration_count_id), input_info (initial_condition_id), input_info (" input" ) }, body_program,
671
+ trip_count_id, initial_condition_id, actual_iteration_count_id,
672
+ input_primitive_maps, output_primitive_maps, back_edges,
673
+ num_iterations, body_current_iteration_id, body_execution_condition_id, 2 ),
674
+ eltwise (" out_sum" , input_info (" loop" , 0 ), input_info (" loop" , 1 ), eltwise_mode::sum));
675
+
676
+ ExecutionConfig config = get_test_default_config (engine);
677
+ config.set_property (ov::intel_gpu::allow_new_shape_infer (true ));
678
+
679
+ cldnn::network::ptr network = get_network (engine, topology, config, get_test_stream_ptr (), is_caching_test);
680
+
681
+
682
+ for (size_t i = 0 ; i < whole_layouts.size (); i++) {
683
+ auto whole_layout = whole_layouts[i];
684
+ auto input_data = input_data_list[i];
685
+
686
+ // initialize input buffers
687
+ set_values (e_initial_condition_mem, {1 });
688
+ set_values (b_exit_value_mem, {exit_value});
689
+ set_values (b_index_inc_mem, {1 });
690
+ set_values (e_num_iteration_mem, {0 });
691
+
692
+ auto e_input_layout = cldnn::layout{ whole_layout, data_types::f32, format::bfyx };
693
+ auto e_input_mem = engine.allocate_memory (e_input_layout); // b,f,x,y
694
+ auto expected_output_layout = whole_layout;
695
+ set_values (e_input_mem, input_data);
696
+ network->set_input_data (" input" , e_input_mem);
697
+
698
+ network->set_input_data (initial_condition_id, e_initial_condition_mem);
699
+
700
+ auto outputs = network->execute ();
701
+ ASSERT_EQ (outputs.size (), 1 );
702
+
703
+ auto expected_num_iterations = (exit_value + 1 );
704
+ expected_output_layout[axis] = expected_num_iterations;
705
+ auto e_output_layout = cldnn::layout{ expected_output_layout, data_types::f32, format::bfyx };
706
+
707
+ auto num_iter_mem = network->get_output_memory (actual_iteration_count_id);
708
+ if (num_iter_mem != nullptr ) {
709
+ mem_lock<int64_t > num_iter_ptr{ num_iter_mem, get_test_stream () };
710
+ ASSERT_EQ (num_iter_ptr.data ()[0 ], expected_num_iterations);
711
+ }
712
+
713
+ std::vector<float > expected (input_data.size ());
714
+ if (expected_output_data.size () == 0 ) {
715
+ size_t unit = 1 ;
716
+ for (size_t k = axis; k < whole_layout.size (); k++) {
717
+ unit *= whole_layout[k].get_length ();
718
+ }
719
+
720
+ for (size_t j = 0 ; j < input_data.size (); j++) {
721
+ auto val = static_cast <size_t >((j % unit) / 4 ) + 1 ;
722
+ expected[j] = static_cast <float >(input_data[j] + val) + static_cast <float >(input_data[j] * val);
723
+ }
724
+ } else {
725
+ expected = expected_output_data;
726
+ }
727
+
728
+ auto output_mem = outputs.begin ()->second .get_memory ();
729
+ auto output_layout = output_mem->get_layout ();
730
+ ASSERT_EQ (output_layout.batch (), e_output_layout.batch ());
731
+ ASSERT_EQ (output_layout.feature (), e_output_layout.feature ());
732
+ ASSERT_EQ (output_layout.spatial (0 ), e_output_layout.spatial (0 ));
733
+ ASSERT_EQ (output_layout.spatial (1 ), e_output_layout.spatial (1 ));
734
+ // value check
735
+ {
736
+ mem_lock<float > output_ptr{ output_mem, get_test_stream () };
737
+ for (size_t i = 0 , iend = output_layout.count (); i < iend; ++i) {
738
+ ASSERT_FLOAT_EQ (output_ptr[i], expected.at (i));
739
+ }
740
+ }
741
+ }
742
+ }
743
+
744
+ std::vector<float > input_data_4_4{
745
+ 1 .0f , 2 .0f , -15 .f , 3 .0f ,
746
+ 4 .0f , -15 .f , 5 .0f , 6 .0f ,
747
+ -15 .f , 7 .0f , -15 .f , 0 .0f ,
748
+ 0 .0f , -15 .f , 0 .5f , -0 .5f ,
749
+ };
750
+
751
+ std::vector<float > input_data_2_4_4{
752
+ 1 .0f , 2 .0f , -15 .f , 3 .0f ,
753
+ 4 .0f , -15 .f , 5 .0f , 6 .0f ,
754
+ -15 .f , 7 .0f , -15 .f , 0 .0f ,
755
+ 0 .0f , -15 .f , 0 .5f , -0 .5f ,
756
+
757
+ 1 .0f , 2 .0f , -15 .f , 3 .0f ,
758
+ 4 .0f , -15 .f , 5 .0f , 6 .0f ,
759
+ -15 .f , 7 .0f , -15 .f , 0 .0f ,
760
+ 0 .0f , -15 .f , 0 .5f , -0 .5f ,
761
+ };
762
+
763
+ TEST (loop_gpu, support_loop_w_dynamic_input_w_various_shapes) {
764
+ test_loop_gpu_wo_trip_count_w_multiple_shapes (
765
+ { 1 , -1 , 4 , 4 },
766
+ {{ 1 , 1 , 4 , 4 }, { 1 , 2 , 4 , 4 }}, // axis value should be iter_num = (exit_value + 1)
767
+ {input_data_4_4, input_data_2_4_4},
768
+ std::vector<float >(),
769
+ 2 , 3 );
770
+ }
0 commit comments