@@ -440,6 +440,132 @@ attach_non_max_suppression_impl::attach_non_max_suppression_impl() {
440
440
}
441
441
442
442
} // namespace detail
443
+
444
+ namespace {
445
+
446
+ std::vector<int32_t > get_nms_gather_input (stream& stream, memory::ptr mem) {
447
+ auto dep_mem_layout = mem->get_layout ();
448
+ auto dep_mem_batch = static_cast <size_t >(dep_mem_layout.batch ());
449
+
450
+ mem_lock<int32_t , mem_lock_type::read > dep_mem_lock (mem, stream);
451
+ auto dep_mem_ptr = dep_mem_lock.data ();
452
+
453
+ size_t actual_valid_num = dep_mem_batch;
454
+ size_t idx = 0 ;
455
+ for (size_t i = 0 ; i < dep_mem_batch; i++) {
456
+ idx = i * 3 ;
457
+ if (dep_mem_ptr[idx] == -1 ) {
458
+ actual_valid_num = i;
459
+ break ;
460
+ }
461
+ }
462
+
463
+ std::vector<int32_t > result;
464
+ for (size_t i = 0 ; i < actual_valid_num; i++) {
465
+ idx = i * 3 ;
466
+ result.push_back (dep_mem_ptr[idx + 0 ]);
467
+ result.push_back (dep_mem_ptr[idx + 1 ]);
468
+ result.push_back (dep_mem_ptr[idx + 2 ]);
469
+ }
470
+
471
+ return result;
472
+ }
473
+
474
+ void store_nms_gather_output (stream& stream, memory::ptr mem, std::vector<int32_t > valid_input) {
475
+ auto valid_input_size = valid_input.size () / 3 ;
476
+
477
+ mem_lock<int32_t , mem_lock_type::write > lock (mem, stream);
478
+ auto ptr = lock.data ();
479
+
480
+ auto output_batch = static_cast <size_t >(mem->get_layout ().batch ());
481
+ for (size_t si = 0 ; si < std::min (valid_input_size, output_batch); ++si) {
482
+ auto offset = si * 3 ;
483
+ ptr[offset + 0 ] = static_cast <int32_t >(valid_input[offset + 0 ]);
484
+ ptr[offset + 1 ] = static_cast <int32_t >(valid_input[offset + 1 ]);
485
+ ptr[offset + 2 ] = static_cast <int32_t >(valid_input[offset + 2 ]);
486
+ }
487
+ }
488
+
489
+
490
+ void store_nms_gather_output2 (stream& stream, memory::ptr mem, std::vector<int32_t > valid_input) {
491
+ auto valid_input_size = valid_input.size () / 3 ;
492
+
493
+ mem_lock<int32_t , mem_lock_type::write > lock (mem, stream);
494
+ auto ptr = lock.data ();
495
+ ptr[0 ] = static_cast <int32_t >(valid_input_size);
496
+ }
497
+
498
+ void run_nms_gather (non_max_suppression_gather_inst& instance) {
499
+ auto & stream = instance.get_network ().get_stream ();
500
+
501
+ // batch_index, class_index, box_index
502
+ auto valid_input0 = get_nms_gather_input (stream, instance.dep_memory_ptr (0 ));
503
+ store_nms_gather_output (stream, instance.output_memory_ptr (), valid_input0);
504
+
505
+ if (instance.outputs_memory_count () >= 2 ) {
506
+ // batch_index, class_index, score
507
+ auto valid_input1 = get_nms_gather_input (stream, instance.dep_memory_ptr (1 ));
508
+ store_nms_gather_output (stream, instance.output_memory_ptr (), valid_input1);
509
+
510
+ if (instance.outputs_memory_count () == 3 ) {
511
+ store_nms_gather_output2 (stream, instance.output_memory_ptr (), valid_input1);
512
+ }
513
+ }
514
+ }
515
+ } // namespace
516
+ struct non_max_suppression_gather_impl : typed_primitive_impl<non_max_suppression_gather> {
517
+ using parent = typed_primitive_impl<non_max_suppression_gather>;
518
+
519
+ DECLARE_OBJECT_TYPE_SERIALIZATION (cldnn::cpu::non_max_suppression_gather_impl)
520
+
521
+ std::unique_ptr<primitive_impl> clone () const override {
522
+ return make_unique<non_max_suppression_gather_impl>(*this );
523
+ }
524
+
525
+ non_max_suppression_gather_impl () : parent(" non_max_suppression_gather_impl" ) {}
526
+
527
+ event::ptr execute_impl (const std::vector<event::ptr>& events, typed_primitive_inst<non_max_suppression_gather>& instance) override {
528
+ auto & stream = instance.get_network ().get_stream ();
529
+
530
+ const bool pass_through_events = (stream.get_queue_type () == QueueTypes::out_of_order) && instance.get_node ().is_in_shape_of_subgraph ();
531
+
532
+ if (!pass_through_events) {
533
+ for (auto e : events) {
534
+ e->wait ();
535
+ }
536
+ }
537
+
538
+ run_nms_gather (instance);
539
+
540
+ if (pass_through_events) {
541
+ if (events.size () > 1 ) {
542
+ return stream.group_events (events);
543
+ } else if (events.size () == 1 ) {
544
+ return events[0 ];
545
+ }
546
+ }
547
+
548
+ return stream.create_user_event (true );
549
+ }
550
+
551
+ static std::unique_ptr<primitive_impl> create (const non_max_suppression_gather_node&, const kernel_impl_params&) {
552
+ return make_unique<non_max_suppression_gather_impl>();
553
+ }
554
+ void init_kernels (const kernels_cache&, const kernel_impl_params&) override {}
555
+ };
556
+
557
+ namespace detail {
558
+
559
+ attach_non_max_suppression_gather_impl::attach_non_max_suppression_gather_impl () {
560
+ implementation_map<non_max_suppression_gather>::add (impl_types::cpu, non_max_suppression_gather_impl::create, {
561
+ std::make_tuple (data_types::i32, format::bfyx),
562
+ std::make_tuple (data_types::f16, format::bfyx),
563
+ std::make_tuple (data_types::f32, format::bfyx),
564
+ });
565
+ }
566
+
567
+ } // namespace detail
568
+
443
569
} // namespace cpu
444
570
} // namespace cldnn
445
571
0 commit comments