@@ -442,6 +442,112 @@ attach_non_max_suppression_impl::attach_non_max_suppression_impl() {
442
442
}
443
443
444
444
} // namespace detail
445
+
446
+ namespace {
447
+
448
+ std::vector<int32_t > get_nms_gather_input (stream& stream, memory::ptr mem) {
449
+ auto dep_mem_layout = mem->get_layout ();
450
+ auto dep_mem_batch = static_cast <size_t >(dep_mem_layout.batch ());
451
+
452
+ mem_lock<int32_t , mem_lock_type::read > dep_mem_lock (mem, stream);
453
+ auto dep_mem_ptr = dep_mem_lock.data ();
454
+
455
+ size_t actual_valid_num = dep_mem_batch;
456
+ size_t idx = 0 ;
457
+ for (size_t i = 0 ; i < dep_mem_batch; i++) {
458
+ idx = i * 3 ;
459
+ if (dep_mem_ptr[idx] == -1 ) {
460
+ actual_valid_num = i;
461
+ break ;
462
+ }
463
+ }
464
+
465
+ std::vector<int32_t > result;
466
+ for (size_t i = 0 ; i < actual_valid_num; i++) {
467
+ idx = i * 3 ;
468
+ result.push_back (dep_mem_ptr[idx + 0 ]);
469
+ result.push_back (dep_mem_ptr[idx + 1 ]);
470
+ result.push_back (dep_mem_ptr[idx + 2 ]);
471
+ }
472
+
473
+ return result;
474
+ }
475
+
476
+ void store_nms_gather_output (stream& stream, memory::ptr mem, std::vector<int32_t > valid_input) {
477
+ auto valid_input_size = valid_input.size () / 3 ;
478
+
479
+ mem_lock<int32_t , mem_lock_type::write > lock (mem, stream);
480
+ auto ptr = lock.data ();
481
+
482
+ auto output_batch = static_cast <size_t >(mem->get_layout ().batch ());
483
+ for (size_t si = 0 ; si < std::min (valid_input_size, output_batch); ++si) {
484
+ auto offset = si * 3 ;
485
+ ptr[offset + 0 ] = static_cast <int32_t >(valid_input[offset + 0 ]);
486
+ ptr[offset + 1 ] = static_cast <int32_t >(valid_input[offset + 1 ]);
487
+ ptr[offset + 2 ] = static_cast <int32_t >(valid_input[offset + 2 ]);
488
+ }
489
+ }
490
+
491
+ void run_nms_gather (non_max_suppression_gather_inst& instance) {
492
+ auto & stream = instance.get_network ().get_stream ();
493
+
494
+ auto valid_input = get_nms_gather_input (stream, instance.dep_memory_ptr (0 ));
495
+ store_nms_gather_output (stream, instance.output_memory_ptr (), valid_input);
496
+ }
497
+ }
498
+ struct non_max_suppression_gather_impl : typed_primitive_impl<non_max_suppression_gather> {
499
+ using parent = typed_primitive_impl<non_max_suppression_gather>;
500
+
501
+ DECLARE_OBJECT_TYPE_SERIALIZATION (cldnn::cpu::non_max_suppression_gather_impl)
502
+
503
+ std::unique_ptr<primitive_impl> clone () const override {
504
+ return make_unique<non_max_suppression_gather_impl>(*this );
505
+ }
506
+
507
+ non_max_suppression_gather_impl () : parent(" non_max_suppression_gather_impl" ) {}
508
+
509
+ event::ptr execute_impl (const std::vector<event::ptr>& events, typed_primitive_inst<non_max_suppression_gather>& instance) override {
510
+ auto & stream = instance.get_network ().get_stream ();
511
+
512
+ const bool pass_through_events = (stream.get_queue_type () == QueueTypes::out_of_order) && instance.get_node ().is_in_shape_of_subgraph ();
513
+
514
+ if (!pass_through_events) {
515
+ for (auto e : events) {
516
+ e->wait ();
517
+ }
518
+ }
519
+
520
+ run_nms_gather (instance);
521
+
522
+ if (pass_through_events) {
523
+ if (events.size () > 1 ) {
524
+ return stream.group_events (events);
525
+ } else if (events.size () == 1 ) {
526
+ return events[0 ];
527
+ }
528
+ }
529
+
530
+ return stream.create_user_event (true );
531
+ }
532
+
533
+ static std::unique_ptr<primitive_impl> create (const non_max_suppression_gather_node&, const kernel_impl_params&) {
534
+ return make_unique<non_max_suppression_gather_impl>();
535
+ }
536
+ void init_kernels (const kernels_cache&, const kernel_impl_params&) override {}
537
+ };
538
+
539
+ namespace detail {
540
+
541
+ attach_non_max_suppression_gather_impl::attach_non_max_suppression_gather_impl () {
542
+ implementation_map<non_max_suppression_gather>::add (impl_types::cpu, non_max_suppression_gather_impl::create, {
543
+ std::make_tuple (data_types::i32, format::bfyx),
544
+ std::make_tuple (data_types::f16, format::bfyx),
545
+ std::make_tuple (data_types::f32, format::bfyx),
546
+ });
547
+ }
548
+
549
+ } // namespace detail
550
+
445
551
} // namespace cpu
446
552
} // namespace cldnn
447
553
0 commit comments