@@ -440,6 +440,121 @@ attach_non_max_suppression_impl::attach_non_max_suppression_impl() {
440
440
}
441
441
442
442
} // namespace detail
443
+
444
+ namespace {
445
+
446
+ template <typename T>
447
+ size_t get_nms_gather_valid_size (stream& stream, memory::ptr mem) {
448
+ auto dep_mem_layout = mem->get_layout ();
449
+ auto dep_mem_batch = static_cast <size_t >(dep_mem_layout.batch ());
450
+
451
+ mem_lock<T, mem_lock_type::read > dep_mem_lock (mem, stream);
452
+ auto dep_mem_ptr = dep_mem_lock.data ();
453
+
454
+ size_t actual_valid_num = dep_mem_batch;
455
+ size_t idx = 0 ;
456
+ for (size_t i = 0 ; i < dep_mem_batch; i++) {
457
+ idx = i * 3 ;
458
+ if (dep_mem_ptr[idx] == -1 ) {
459
+ actual_valid_num = i;
460
+ break ;
461
+ }
462
+ }
463
+
464
+ return actual_valid_num;
465
+ }
466
+
467
+ template <typename T>
468
+ void store_nms_gather_output (non_max_suppression_gather_inst& instance, size_t idx, size_t valid_size) {
469
+ auto input_mem = instance.dep_memory_ptr (idx);
470
+ layout dep_layout = input_mem->get_layout ();
471
+ auto output_ps = dep_layout.get_partial_shape ();
472
+
473
+ output_ps[0 ] = valid_size; // update valid batch size
474
+ auto output_layout = layout (output_ps, dep_layout.data_type , dep_layout.format );
475
+ auto new_output_mem = instance.get_network ().get_engine ().reinterpret_buffer (*input_mem, output_layout);
476
+
477
+ instance.set_output_memory (new_output_mem, true , idx);
478
+ }
479
+
480
+ void run_nms_gather (non_max_suppression_gather_inst& instance) {
481
+ auto & stream = instance.get_network ().get_stream ();
482
+
483
+ auto valid_input_batch = get_nms_gather_valid_size<ov::element_type_traits<data_types::i32>::value_type>(stream, instance.dep_memory_ptr (0 ));
484
+ store_nms_gather_output<ov::element_type_traits<data_types::i32>::value_type>(instance, 0 , valid_input_batch);
485
+
486
+ if (instance.outputs_memory_count () >= 2 ) {
487
+ auto data_type = instance.dep_memory_ptr (1 )->get_layout ().data_type ;
488
+
489
+ if (data_type == cldnn::data_types::f16) {
490
+ store_nms_gather_output<ov::element_type_traits<data_types::f16>::value_type>(instance, 1 , valid_input_batch);
491
+ } else if (data_type == cldnn::data_types::f32) {
492
+ store_nms_gather_output<ov::element_type_traits<data_types::f32>::value_type>(instance, 1 , valid_input_batch);
493
+ } else {
494
+ throw std::runtime_error (" Non max suppression gather - unsupported second output data type" );
495
+ }
496
+
497
+ if (instance.outputs_memory_count () == 3 ) {
498
+ mem_lock<ov::element_type_traits<data_types::i32>::value_type, mem_lock_type::write > lock (instance.output_memory_ptr (2 ), stream);
499
+ auto ptr = lock.data ();
500
+ ptr[0 ] = static_cast <ov::element_type_traits<data_types::i32>::value_type>(valid_input_batch);
501
+ }
502
+ }
503
+ }
504
+ } // namespace
505
+ struct non_max_suppression_gather_impl : typed_primitive_impl<non_max_suppression_gather> {
506
+ using parent = typed_primitive_impl<non_max_suppression_gather>;
507
+
508
+ DECLARE_OBJECT_TYPE_SERIALIZATION (cldnn::cpu::non_max_suppression_gather_impl)
509
+
510
+ std::unique_ptr<primitive_impl> clone () const override {
511
+ return make_unique<non_max_suppression_gather_impl>(*this );
512
+ }
513
+
514
+ non_max_suppression_gather_impl () : parent(" non_max_suppression_gather_impl" ) {}
515
+
516
+ event::ptr execute_impl (const std::vector<event::ptr>& events, typed_primitive_inst<non_max_suppression_gather>& instance) override {
517
+ auto & stream = instance.get_network ().get_stream ();
518
+
519
+ const bool pass_through_events = (stream.get_queue_type () == QueueTypes::out_of_order) && instance.get_node ().is_in_shape_of_subgraph ();
520
+
521
+ if (!pass_through_events) {
522
+ for (auto e : events) {
523
+ e->wait ();
524
+ }
525
+ }
526
+
527
+ run_nms_gather (instance);
528
+
529
+ if (pass_through_events) {
530
+ if (events.size () > 1 ) {
531
+ return stream.group_events (events);
532
+ } else if (events.size () == 1 ) {
533
+ return events[0 ];
534
+ }
535
+ }
536
+
537
+ return stream.create_user_event (true );
538
+ }
539
+
540
+ static std::unique_ptr<primitive_impl> create (const non_max_suppression_gather_node&, const kernel_impl_params&) {
541
+ return make_unique<non_max_suppression_gather_impl>();
542
+ }
543
+ void init_kernels (const kernels_cache&, const kernel_impl_params&) override {}
544
+ };
545
+
546
+ namespace detail {
547
+
548
+ attach_non_max_suppression_gather_impl::attach_non_max_suppression_gather_impl () {
549
+ implementation_map<non_max_suppression_gather>::add (impl_types::cpu, non_max_suppression_gather_impl::create, {
550
+ std::make_tuple (data_types::i32, format::bfyx),
551
+ std::make_tuple (data_types::f16, format::bfyx),
552
+ std::make_tuple (data_types::f32, format::bfyx),
553
+ });
554
+ }
555
+
556
+ } // namespace detail
557
+
443
558
} // namespace cpu
444
559
} // namespace cldnn
445
560
0 commit comments