@@ -440,6 +440,159 @@ attach_non_max_suppression_impl::attach_non_max_suppression_impl() {
440
440
}
441
441
442
442
} // namespace detail
443
+
444
+ namespace {
445
+
446
+ template <typename T>
447
+ std::vector<T> get_nms_gather_input (stream& stream, memory::ptr mem) {
448
+ auto dep_mem_layout = mem->get_layout ();
449
+ auto dep_mem_batch = static_cast <size_t >(dep_mem_layout.batch ());
450
+
451
+ mem_lock<T, mem_lock_type::read > dep_mem_lock (mem, stream);
452
+ auto dep_mem_ptr = dep_mem_lock.data ();
453
+
454
+ size_t actual_valid_num = dep_mem_batch;
455
+ size_t idx = 0 ;
456
+ for (size_t i = 0 ; i < dep_mem_batch; i++) {
457
+ idx = i * 3 ;
458
+ if (dep_mem_ptr[idx] == -1 ) {
459
+ actual_valid_num = i;
460
+ break ;
461
+ }
462
+ }
463
+
464
+ std::vector<T> result;
465
+ for (size_t i = 0 ; i < actual_valid_num; i++) {
466
+ idx = i * 3 ;
467
+ result.push_back (dep_mem_ptr[idx + 0 ]);
468
+ result.push_back (dep_mem_ptr[idx + 1 ]);
469
+ result.push_back (dep_mem_ptr[idx + 2 ]);
470
+ }
471
+
472
+ return result;
473
+ }
474
+
475
+ template <typename T>
476
+ void store_nms_gather_output0 (stream& stream, memory::ptr mem, std::vector<T> valid_input) {
477
+ auto valid_input_size = valid_input.size () / 3 ;
478
+
479
+ mem_lock<T, mem_lock_type::write > lock (mem, stream);
480
+ auto ptr = lock.data ();
481
+
482
+ auto output_batch = static_cast <size_t >(mem->get_layout ().batch ());
483
+ for (size_t si = 0 ; si < std::min (valid_input_size, output_batch); ++si) {
484
+ auto offset = si * 3 ;
485
+ // batch_index, class_index, box_index
486
+ ptr[offset + 0 ] = static_cast <T>(valid_input[offset + 0 ]);
487
+ ptr[offset + 1 ] = static_cast <T>(valid_input[offset + 1 ]);
488
+ ptr[offset + 2 ] = static_cast <T>(valid_input[offset + 2 ]);
489
+ }
490
+ }
491
+
492
+ template <typename T>
493
+ void store_nms_gather_output1 (stream& stream, memory::ptr mem, std::vector<T> valid_input) {
494
+ auto valid_input_size = valid_input.size () / 3 ;
495
+
496
+ mem_lock<T, mem_lock_type::write > lock (mem, stream);
497
+ auto ptr = lock.data ();
498
+
499
+ auto output_batch = static_cast <size_t >(mem->get_layout ().batch ());
500
+ for (size_t si = 0 ; si < std::min (valid_input_size, output_batch); ++si) {
501
+ auto offset = si * 3 ;
502
+ // batch_index, class_index, score
503
+ ptr[offset + 0 ] = static_cast <T>(valid_input[offset + 0 ]);
504
+ ptr[offset + 1 ] = static_cast <T>(valid_input[offset + 1 ]);
505
+ ptr[offset + 2 ] = static_cast <T>(valid_input[offset + 2 ]);
506
+ }
507
+ }
508
+
509
+ template <typename T>
510
+ void store_nms_gather_output2 (stream& stream, memory::ptr mem, std::vector<int32_t > valid_input) {
511
+ auto valid_input_size = valid_input.size () / 3 ;
512
+
513
+ mem_lock<T, mem_lock_type::write > lock (mem, stream);
514
+ auto ptr = lock.data ();
515
+ ptr[0 ] = static_cast <T>(valid_input_size);
516
+ }
517
+
518
+ void run_nms_gather (non_max_suppression_gather_inst& instance) {
519
+ auto & stream = instance.get_network ().get_stream ();
520
+
521
+ auto valid_input0 = get_nms_gather_input<ov::element_type_traits<data_types::i32>::value_type>(stream, instance.dep_memory_ptr (0 ));
522
+ store_nms_gather_output0<ov::element_type_traits<data_types::i32>::value_type>(stream, instance.output_memory_ptr (0 ), valid_input0);
523
+
524
+ if (instance.outputs_memory_count () >= 2 ) {
525
+ auto data_type = instance.dep_memory_ptr (1 )->get_layout ().data_type ;
526
+
527
+ if (data_type == cldnn::data_types::f16) {
528
+ auto valid_input_f16 = get_nms_gather_input<ov::element_type_traits<data_types::f16>::value_type>(stream, instance.dep_memory_ptr (1 ));
529
+ store_nms_gather_output1<ov::element_type_traits<data_types::f16>::value_type>(stream, instance.output_memory_ptr (1 ), valid_input_f16);
530
+ } else if (data_type == cldnn::data_types::f32) {
531
+ auto valid_input_f32 = get_nms_gather_input<ov::element_type_traits<data_types::f32>::value_type>(stream, instance.dep_memory_ptr (1 ));
532
+ store_nms_gather_output1<ov::element_type_traits<data_types::f32>::value_type>(stream, instance.output_memory_ptr (1 ), valid_input_f32);
533
+ } else {
534
+ throw std::runtime_error (" Non max suppression gather - unsupported second output data type" );
535
+ }
536
+
537
+ if (instance.outputs_memory_count () == 3 ) {
538
+ store_nms_gather_output2<ov::element_type_traits<data_types::i32>::value_type>(stream, instance.output_memory_ptr (2 ), valid_input0);
539
+ }
540
+ }
541
+ }
542
+ } // namespace
543
+ struct non_max_suppression_gather_impl : typed_primitive_impl<non_max_suppression_gather> {
544
+ using parent = typed_primitive_impl<non_max_suppression_gather>;
545
+
546
+ DECLARE_OBJECT_TYPE_SERIALIZATION (cldnn::cpu::non_max_suppression_gather_impl)
547
+
548
+ std::unique_ptr<primitive_impl> clone () const override {
549
+ return make_unique<non_max_suppression_gather_impl>(*this );
550
+ }
551
+
552
+ non_max_suppression_gather_impl () : parent(" non_max_suppression_gather_impl" ) {}
553
+
554
+ event::ptr execute_impl (const std::vector<event::ptr>& events, typed_primitive_inst<non_max_suppression_gather>& instance) override {
555
+ auto & stream = instance.get_network ().get_stream ();
556
+
557
+ const bool pass_through_events = (stream.get_queue_type () == QueueTypes::out_of_order) && instance.get_node ().is_in_shape_of_subgraph ();
558
+
559
+ if (!pass_through_events) {
560
+ for (auto e : events) {
561
+ e->wait ();
562
+ }
563
+ }
564
+
565
+ run_nms_gather (instance);
566
+
567
+ if (pass_through_events) {
568
+ if (events.size () > 1 ) {
569
+ return stream.group_events (events);
570
+ } else if (events.size () == 1 ) {
571
+ return events[0 ];
572
+ }
573
+ }
574
+
575
+ return stream.create_user_event (true );
576
+ }
577
+
578
+ static std::unique_ptr<primitive_impl> create (const non_max_suppression_gather_node&, const kernel_impl_params&) {
579
+ return make_unique<non_max_suppression_gather_impl>();
580
+ }
581
+ void init_kernels (const kernels_cache&, const kernel_impl_params&) override {}
582
+ };
583
+
584
+ namespace detail {
585
+
586
+ attach_non_max_suppression_gather_impl::attach_non_max_suppression_gather_impl () {
587
+ implementation_map<non_max_suppression_gather>::add (impl_types::cpu, non_max_suppression_gather_impl::create, {
588
+ std::make_tuple (data_types::i32, format::bfyx),
589
+ std::make_tuple (data_types::f16, format::bfyx),
590
+ std::make_tuple (data_types::f32, format::bfyx),
591
+ });
592
+ }
593
+
594
+ } // namespace detail
595
+
443
596
} // namespace cpu
444
597
} // namespace cldnn
445
598
0 commit comments