|
11 | 11 | #include "nms_shape_inference.hpp"
|
12 | 12 |
|
13 | 13 | namespace cldnn {
|
| 14 | + |
| 15 | +// ----------------------------------------------- |
| 16 | +// non_max_suppression |
| 17 | +// ----------------------------------------------- |
14 | 18 | GPU_DEFINE_PRIMITIVE_TYPE_ID(non_max_suppression)
|
15 | 19 |
|
16 | 20 | layout non_max_suppression_inst::calc_output_layout(non_max_suppression_node const& node, kernel_impl_params const& impl_param) {
|
@@ -81,4 +85,92 @@ std::string non_max_suppression_inst::to_string(non_max_suppression_node const&
|
81 | 85 | return description.str();
|
82 | 86 | }
|
83 | 87 |
|
| 88 | +// ----------------------------------------------- |
| 89 | +// non_max_suppression_gather |
| 90 | +// ----------------------------------------------- |
| 91 | +GPU_DEFINE_PRIMITIVE_TYPE_ID(non_max_suppression_gather) |
| 92 | + |
| 93 | +layout non_max_suppression_gather_inst::calc_output_layout(non_max_suppression_gather_node const& node, kernel_impl_params const& impl_param) { |
| 94 | + OPENVINO_THROW("Only calc_output_layouts should be used!"); |
| 95 | +} |
| 96 | + |
| 97 | +template<typename ShapeType> |
| 98 | +std::vector<layout> non_max_suppression_gather_inst::calc_output_layouts(non_max_suppression_gather_node const& /*node*/, |
| 99 | + const kernel_impl_params& impl_param) { |
| 100 | + std::vector<layout> layouts; |
| 101 | + |
| 102 | + auto desc = impl_param.typed_desc<non_max_suppression_gather>(); |
| 103 | + std::vector<ShapeType> output_shapes = { ShapeType{}, ShapeType{}, ShapeType{} }; |
| 104 | + |
| 105 | + auto& memory_deps = impl_param.memory_deps; |
| 106 | + if (memory_deps.count(0)) { |
| 107 | + auto actual_output = memory_deps.at(0); |
| 108 | + cldnn::mem_lock<int32_t, mem_lock_type::read> actual_output_lock(actual_output, impl_param.get_stream()); |
| 109 | + |
| 110 | + auto output_ps = actual_output->get_layout().get_partial_shape(); |
| 111 | + auto b = output_ps[0].get_length(); |
| 112 | + auto f = output_ps[1].get_length(); |
| 113 | + |
| 114 | + // find valid data size |
| 115 | + auto output_data = actual_output_lock.data(); |
| 116 | + int64_t actual_valid_num = b; |
| 117 | + for (int64_t i = 0; i < b ; i += 1) { |
| 118 | + if (output_data[i * f] == -1) { |
| 119 | + actual_valid_num = i; |
| 120 | + break; |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + output_shapes[0] = output_shapes[1] = ShapeType{actual_valid_num, f}; |
| 125 | + output_shapes[2] = ShapeType{1}; |
| 126 | + } else { |
| 127 | + output_shapes[0] = output_shapes[1] = ShapeType{ov::Dimension::dynamic(), 3}; |
| 128 | + output_shapes[2] = ShapeType{1}; |
| 129 | + } |
| 130 | + |
| 131 | + for (size_t i = 0; i < desc->num_outputs; ++i) { |
| 132 | + layouts.push_back({output_shapes[i], |
| 133 | + impl_param.get_input_layout(i).data_type, |
| 134 | + format::get_default_format(output_shapes[i].size())}); |
| 135 | + } |
| 136 | + return layouts; |
| 137 | +} |
| 138 | + |
| 139 | +template std::vector<layout> non_max_suppression_gather_inst::calc_output_layouts<ov::PartialShape>(non_max_suppression_gather_node const& node, |
| 140 | + const kernel_impl_params& impl_param); |
| 141 | + |
| 142 | +std::string non_max_suppression_gather_inst::to_string(non_max_suppression_gather_node const& node) { |
| 143 | + auto desc = node.get_primitive(); |
| 144 | + auto node_info = node.desc_to_json(); |
| 145 | + |
| 146 | + json_composite info; |
| 147 | + |
| 148 | + node_info->add("non max suppression gather info", info); |
| 149 | + |
| 150 | + std::stringstream description; |
| 151 | + node_info->dump(description); |
| 152 | + return description.str(); |
| 153 | +} |
| 154 | + |
| 155 | +void non_max_suppression_gather_inst::on_execute() { |
| 156 | + update_output_memory(); |
| 157 | +} |
| 158 | + |
| 159 | +void non_max_suppression_gather_inst::update_output_memory() { |
| 160 | + if (!can_be_optimized()) |
| 161 | + return; |
| 162 | + |
| 163 | + for (size_t i = 0; i < inputs_memory_count(); i++) { |
| 164 | + if (node->get_program().is_new_shape_infer() && input_memory_ptr(i) == nullptr) |
| 165 | + return; |
| 166 | + |
| 167 | + if (output_memory_ptr(i) != nullptr && _network.get_engine().is_the_same_buffer(output_memory(i), input_memory(i))) |
| 168 | + return; |
| 169 | + |
| 170 | + _outputs[i] = {_network.get_engine().reinterpret_buffer(input_memory(i), _impl_params->get_output_layout(i))}; |
| 171 | + } |
| 172 | +} |
| 173 | + |
| 174 | +non_max_suppression_gather_inst::typed_primitive_inst(network& network, non_max_suppression_gather_node const& node) : parent(network, node) {} |
| 175 | + |
84 | 176 | } // namespace cldnn
|
0 commit comments