|
11 | 11 | #include "nms_shape_inference.hpp"
|
12 | 12 |
|
13 | 13 | namespace cldnn {
|
| 14 | + |
| 15 | +// ----------------------------------------------- |
| 16 | +// non_max_suppression |
| 17 | +// ----------------------------------------------- |
14 | 18 | GPU_DEFINE_PRIMITIVE_TYPE_ID(non_max_suppression)
|
15 | 19 |
|
16 | 20 | layout non_max_suppression_inst::calc_output_layout(non_max_suppression_node const& node, kernel_impl_params const& impl_param) {
|
@@ -81,4 +85,79 @@ std::string non_max_suppression_inst::to_string(non_max_suppression_node const&
|
81 | 85 | return description.str();
|
82 | 86 | }
|
83 | 87 |
|
| 88 | +// ----------------------------------------------- |
| 89 | +// non_max_suppression_gather |
| 90 | +// ----------------------------------------------- |
| 91 | +GPU_DEFINE_PRIMITIVE_TYPE_ID(non_max_suppression_gather) |
| 92 | + |
| 93 | +layout non_max_suppression_gather_inst::calc_output_layout(non_max_suppression_gather_node const& node, kernel_impl_params const& impl_param) { |
| 94 | + OPENVINO_THROW("Only calc_output_layouts should be used!"); |
| 95 | +} |
| 96 | + |
| 97 | +template<typename ShapeType> |
| 98 | +std::vector<layout> non_max_suppression_gather_inst::calc_output_layouts(non_max_suppression_gather_node const& /*node*/, |
| 99 | + const kernel_impl_params& impl_param) { |
| 100 | + std::vector<layout> layouts; |
| 101 | + |
| 102 | + auto desc = impl_param.typed_desc<non_max_suppression_gather>(); |
| 103 | + std::vector<ShapeType> output_shapes = { ShapeType{}, ShapeType{}, ShapeType{} }; |
| 104 | + |
| 105 | + auto& memory_deps = impl_param.memory_deps; |
| 106 | + if (memory_deps.count(2)) { |
| 107 | + auto third_output = memory_deps.at(2); |
| 108 | + cldnn::mem_lock<int32_t, mem_lock_type::read> third_output_lock(third_output, impl_param.get_stream()); |
| 109 | + auto third_output_data = third_output_lock.data(); |
| 110 | + |
| 111 | + output_shapes[0] = ShapeType{third_output_data[0], 3}; |
| 112 | + } else { |
| 113 | + output_shapes[0] = ShapeType{ov::Dimension::dynamic(), 3}; |
| 114 | + } |
| 115 | + output_shapes[1] = output_shapes[0]; |
| 116 | + output_shapes[2] = ShapeType{1}; |
| 117 | + |
| 118 | + for (size_t i = 0; i < desc->num_outputs; ++i) { |
| 119 | + layouts.push_back({output_shapes[i], |
| 120 | + impl_param.get_input_layout(i).data_type, |
| 121 | + format::get_default_format(output_shapes[i].size())}); |
| 122 | + } |
| 123 | + return layouts; |
| 124 | +} |
| 125 | + |
| 126 | +template std::vector<layout> non_max_suppression_gather_inst::calc_output_layouts<ov::PartialShape>(non_max_suppression_gather_node const& node, |
| 127 | + const kernel_impl_params& impl_param); |
| 128 | + |
| 129 | +std::string non_max_suppression_gather_inst::to_string(non_max_suppression_gather_node const& node) { |
| 130 | + auto desc = node.get_primitive(); |
| 131 | + auto node_info = node.desc_to_json(); |
| 132 | + |
| 133 | + json_composite info; |
| 134 | + |
| 135 | + node_info->add("non max suppression gather info", info); |
| 136 | + |
| 137 | + std::stringstream description; |
| 138 | + node_info->dump(description); |
| 139 | + return description.str(); |
| 140 | +} |
| 141 | + |
| 142 | +void non_max_suppression_gather_inst::on_execute() { |
| 143 | + update_output_memory(); |
| 144 | +} |
| 145 | + |
| 146 | +void non_max_suppression_gather_inst::update_output_memory() { |
| 147 | + if (!can_be_optimized()) |
| 148 | + return; |
| 149 | + |
| 150 | + for (size_t i = 0; i < inputs_memory_count(); i++) { |
| 151 | + if (node->get_program().is_new_shape_infer() && input_memory_ptr(i) == nullptr) |
| 152 | + return; |
| 153 | + |
| 154 | + if (output_memory_ptr(i) != nullptr && _network.get_engine().is_the_same_buffer(output_memory(i), input_memory(i))) |
| 155 | + return; |
| 156 | + |
| 157 | + _outputs[i] = {_network.get_engine().reinterpret_buffer(input_memory(i), _impl_params->get_output_layout(i))}; |
| 158 | + } |
| 159 | +} |
| 160 | + |
| 161 | +non_max_suppression_gather_inst::typed_primitive_inst(network& network, non_max_suppression_gather_node const& node) : parent(network, node) {} |
| 162 | + |
84 | 163 | } // namespace cldnn
|
0 commit comments