@@ -440,6 +440,112 @@ attach_non_max_suppression_impl::attach_non_max_suppression_impl() {
440
440
}
441
441
442
442
} // namespace detail
443
+
444
+ namespace {
445
+
446
+ std::vector<int32_t > get_nms_gather_input (stream& stream, memory::ptr mem) {
447
+ auto dep_mem_layout = mem->get_layout ();
448
+ auto dep_mem_batch = static_cast <size_t >(dep_mem_layout.batch ());
449
+
450
+ mem_lock<int32_t , mem_lock_type::read > dep_mem_lock (mem, stream);
451
+ auto dep_mem_ptr = dep_mem_lock.data ();
452
+
453
+ size_t actual_valid_num = dep_mem_batch;
454
+ size_t idx = 0 ;
455
+ for (size_t i = 0 ; i < dep_mem_batch; i++) {
456
+ idx = i * 3 ;
457
+ if (dep_mem_ptr[idx] == -1 ) {
458
+ actual_valid_num = i;
459
+ break ;
460
+ }
461
+ }
462
+
463
+ std::vector<int32_t > result;
464
+ for (size_t i = 0 ; i < actual_valid_num; i++) {
465
+ idx = i * 3 ;
466
+ result.push_back (dep_mem_ptr[idx + 0 ]);
467
+ result.push_back (dep_mem_ptr[idx + 1 ]);
468
+ result.push_back (dep_mem_ptr[idx + 2 ]);
469
+ }
470
+
471
+ return result;
472
+ }
473
+
474
+ void store_nms_gather_output (stream& stream, memory::ptr mem, std::vector<int32_t > valid_input) {
475
+ auto valid_input_size = valid_input.size () / 3 ;
476
+
477
+ mem_lock<int32_t , mem_lock_type::write > lock (mem, stream);
478
+ auto ptr = lock.data ();
479
+
480
+ auto output_batch = static_cast <size_t >(mem->get_layout ().batch ());
481
+ for (size_t si = 0 ; si < std::min (valid_input_size, output_batch); ++si) {
482
+ auto offset = si * 3 ;
483
+ ptr[offset + 0 ] = static_cast <int32_t >(valid_input[offset + 0 ]);
484
+ ptr[offset + 1 ] = static_cast <int32_t >(valid_input[offset + 1 ]);
485
+ ptr[offset + 2 ] = static_cast <int32_t >(valid_input[offset + 2 ]);
486
+ }
487
+ }
488
+
489
+ void run_nms_gather (non_max_suppression_gather_inst& instance) {
490
+ auto & stream = instance.get_network ().get_stream ();
491
+
492
+ auto valid_input = get_nms_gather_input (stream, instance.dep_memory_ptr (0 ));
493
+ store_nms_gather_output (stream, instance.output_memory_ptr (), valid_input);
494
+ }
495
+ }
496
+ struct non_max_suppression_gather_impl : typed_primitive_impl<non_max_suppression_gather> {
497
+ using parent = typed_primitive_impl<non_max_suppression_gather>;
498
+
499
+ DECLARE_OBJECT_TYPE_SERIALIZATION (cldnn::cpu::non_max_suppression_gather_impl)
500
+
501
+ std::unique_ptr<primitive_impl> clone () const override {
502
+ return make_unique<non_max_suppression_gather_impl>(*this );
503
+ }
504
+
505
+ non_max_suppression_gather_impl () : parent(" non_max_suppression_gather_impl" ) {}
506
+
507
+ event::ptr execute_impl (const std::vector<event::ptr>& events, typed_primitive_inst<non_max_suppression_gather>& instance) override {
508
+ auto & stream = instance.get_network ().get_stream ();
509
+
510
+ const bool pass_through_events = (stream.get_queue_type () == QueueTypes::out_of_order) && instance.get_node ().is_in_shape_of_subgraph ();
511
+
512
+ if (!pass_through_events) {
513
+ for (auto e : events) {
514
+ e->wait ();
515
+ }
516
+ }
517
+
518
+ run_nms_gather (instance);
519
+
520
+ if (pass_through_events) {
521
+ if (events.size () > 1 ) {
522
+ return stream.group_events (events);
523
+ } else if (events.size () == 1 ) {
524
+ return events[0 ];
525
+ }
526
+ }
527
+
528
+ return stream.create_user_event (true );
529
+ }
530
+
531
+ static std::unique_ptr<primitive_impl> create (const non_max_suppression_gather_node&, const kernel_impl_params&) {
532
+ return make_unique<non_max_suppression_gather_impl>();
533
+ }
534
+ void init_kernels (const kernels_cache&, const kernel_impl_params&) override {}
535
+ };
536
+
537
+ namespace detail {
538
+
539
+ attach_non_max_suppression_gather_impl::attach_non_max_suppression_gather_impl () {
540
+ implementation_map<non_max_suppression_gather>::add (impl_types::cpu, non_max_suppression_gather_impl::create, {
541
+ std::make_tuple (data_types::i32, format::bfyx),
542
+ std::make_tuple (data_types::f16, format::bfyx),
543
+ std::make_tuple (data_types::f32, format::bfyx),
544
+ });
545
+ }
546
+
547
+ } // namespace detail
548
+
443
549
} // namespace cpu
444
550
} // namespace cldnn
445
551
0 commit comments