Skip to content

Commit 18e97b6

Browse files
p-wysockirkazantspraaszmitruska
authored
Add SegmentMax-16 reference implementation (#29047)
### Details: - The original PR (#28788) has been mistakenly force-merged due to a mistake in merge queue settings. It was later reverted, so this is the "new" Ref PR. - Add reference implementation - Add tests ### Related PRs: - #28103 - #28698 - #28979 - #28999 ### Tickets: - CVS-158917 --------- Signed-off-by: p-wysocki <przemyslaw.wysocki@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> Co-authored-by: Pawel Raasz <pawel.raasz@intel.com> Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
1 parent e58f38f commit 18e97b6

File tree

11 files changed

+373
-14
lines changed

11 files changed

+373
-14
lines changed

src/core/include/openvino/op/util/attr_types.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ enum class PadMode { CONSTANT = 0, EDGE, REFLECT, SYMMETRIC };
2020
OPENVINO_API
2121
std::ostream& operator<<(std::ostream& s, const PadMode& type);
2222

23-
/// \brief Fill modes for the `SegmentMax` operator.
23+
/// \brief Fill modes to set default value for operators like `SegmentMax`.
2424
enum class FillMode { ZERO = 0, LOWEST };
2525

2626
OPENVINO_API

src/core/include/openvino/opsets/opset16_tbl.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3)
1616
// New operations added in opset16
1717
_OPENVINO_OP_REG(Identity, ov::op::v16)
1818
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
19+
_OPENVINO_OP_REG(SegmentMax, ov::op::v16)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <algorithm>
8+
#include <limits>
9+
#include <vector>
10+
11+
#include "openvino/core/shape.hpp"
12+
13+
namespace ov::reference {
14+
15+
template <typename T, typename T_idx, std::enable_if_t<std::is_same<std::decay_t<T_idx>, int64_t>::value>* = nullptr>
16+
void segment_max(const T* data,
17+
const Shape& data_shape,
18+
const T_idx* segment_ids,
19+
T* out,
20+
const Shape& output_shape,
21+
const T empty_segment_value) {
22+
const T_idx num_segments = output_shape[0];
23+
const auto inner_dim_size = shape_size(data_shape.begin() + 1, data_shape.end());
24+
25+
// Initialize output with empty_segment_value
26+
std::fill(out, out + num_segments * inner_dim_size, empty_segment_value);
27+
28+
// Iterate over each element in the first dimension
29+
for (size_t i = 0; i < data_shape[0]; ++i) {
30+
const T_idx segment_id = segment_ids[i];
31+
if (segment_id >= num_segments) {
32+
continue;
33+
}
34+
// Iterate over each element in the inner dimensions
35+
for (size_t j = 0; j < inner_dim_size; ++j) {
36+
const size_t index = i * inner_dim_size + j;
37+
const size_t out_index = segment_id * inner_dim_size + j;
38+
// Update the maximum value for the current segment and inner dimension
39+
out[out_index] = std::max(out[out_index], data[index]);
40+
}
41+
}
42+
}
43+
44+
template <typename T, typename T_idx, std::enable_if_t<!std::is_same<std::decay_t<T_idx>, int64_t>::value>* = nullptr>
45+
void segment_max(const T* data,
46+
const Shape& data_shape,
47+
const T_idx* segment_ids,
48+
T* out,
49+
const Shape& output_shape,
50+
const T empty_segment_value) {
51+
std::vector<int64_t> segment_ids_int64(segment_ids, segment_ids + data_shape[0]);
52+
segment_max(data, data_shape, segment_ids_int64.data(), out, output_shape, empty_segment_value);
53+
}
54+
55+
} // namespace ov::reference

src/core/shape_inference/include/segment_max_shape_inference.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ std::vector<TRShape> shape_infer(const SegmentMax* op,
5050

5151
// validate num_segments input
5252
const auto num_segments_available = op->inputs().size() == 3;
53-
const auto num_segments = num_segments_available ? get_input_const_data_as_shape<TRShape>(op, 2, tensor_accessor)
54-
: ov::optional<TRShape>{};
53+
ov::optional<TRShape> num_segments;
54+
if (num_segments_available) {
55+
num_segments = get_input_const_data_as_shape<TRShape>(op, 2, tensor_accessor);
56+
}
57+
5558
if (num_segments_available) {
5659
const auto& num_segments_shape = input_shapes[2];
5760
NODE_SHAPE_INFER_CHECK(op,

src/core/tests/opset.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
7777
OpsetTestParams{ov::get_opset13, 186},
7878
OpsetTestParams{ov::get_opset14, 188},
7979
OpsetTestParams{ov::get_opset15, 199},
80-
OpsetTestParams{ov::get_opset16, 5}),
80+
OpsetTestParams{ov::get_opset16, 6}),
8181
OpsetTestNameGenerator{});
8282

8383
class MyOpOld : public ov::op::Op {

src/core/tests/type_prop/segment_max.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
namespace ov::test {
1818
using op::v0::Constant, op::v0::Parameter, op::v1::Add, op::v1::ReduceMax, op::v1::StridedSlice, op::v3::ShapeOf;
19+
using testing::HasSubstr;
1920

2021
class TypePropSegmentMaxTest : public TypePropOpTest<op::v16::SegmentMax> {};
2122

@@ -69,45 +70,44 @@ TEST_F(TypePropSegmentMaxTest, incorrect_inputs) {
6970
const auto num_segments_f32 = std::make_shared<Parameter>(element::f32, PartialShape{});
7071
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids, num_segments_f32, op::FillMode::LOWEST),
7172
ov::NodeValidationFailure,
72-
testing::HasSubstr("The element type of the num_segments input be i32 or i64."));
73+
HasSubstr("The element type of the num_segments input be i32 or i64."));
7374
}
7475
{
7576
const auto segment_ids_f32 = std::make_shared<Parameter>(element::f32, PartialShape{3});
7677
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_f32, num_segments, op::FillMode::LOWEST),
7778
ov::NodeValidationFailure,
78-
testing::HasSubstr("The element type of the segment_ids input be i32 or i64."));
79+
HasSubstr("The element type of the segment_ids input be i32 or i64."));
7980
}
8081
{
8182
const auto segment_ids_nd = std::make_shared<Parameter>(element::i32, PartialShape{2, 3});
8283
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_nd, num_segments, op::FillMode::LOWEST),
8384
ov::NodeValidationFailure,
84-
testing::HasSubstr("segment_ids must be a 1D input."));
85+
HasSubstr("segment_ids must be a 1D input."));
8586
}
8687
{
8788
const auto num_segments_nd = std::make_shared<Parameter>(element::i32, PartialShape{1});
8889
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids, num_segments_nd, op::FillMode::LOWEST),
8990
ov::NodeValidationFailure,
90-
testing::HasSubstr("num_segments must be a scalar input."));
91+
HasSubstr("num_segments must be a scalar input."));
9192
}
9293
{
9394
const auto segment_ids_unsorted =
9495
std::make_shared<Constant>(element::i32, Shape{3}, std::vector<int64_t>{1, 0, 1});
9596
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_unsorted, num_segments, op::FillMode::LOWEST),
9697
ov::NodeValidationFailure,
97-
testing::HasSubstr("segment_ids must be sorted."));
98+
HasSubstr("segment_ids must be sorted."));
9899
}
99100
{
100101
const auto data_scalar = std::make_shared<Parameter>(element::i32, PartialShape{});
101102
OV_EXPECT_THROW(std::ignore = make_op(data_scalar, segment_ids, num_segments, op::FillMode::LOWEST),
102103
ov::NodeValidationFailure,
103-
testing::HasSubstr("The data input cannot be a scalar."));
104+
HasSubstr("The data input cannot be a scalar."));
104105
}
105106
{
106107
const auto segment_ids_short = std::make_shared<Constant>(element::i32, Shape{2}, std::vector<int64_t>{1, 0});
107-
OV_EXPECT_THROW(
108-
std::ignore = make_op(data, segment_ids_short, num_segments, op::FillMode::LOWEST),
109-
ov::NodeValidationFailure,
110-
testing::HasSubstr("The number of elements in segment_ids must match the first dimension of data."));
108+
OV_EXPECT_THROW(std::ignore = make_op(data, segment_ids_short, num_segments, op::FillMode::LOWEST),
109+
ov::NodeValidationFailure,
110+
HasSubstr("The number of elements in segment_ids must match the first dimension of data."));
111111
}
112112
}
113113

src/plugins/template/backend/ops/ops_evaluates.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,7 @@ extern template bool evaluate_node<ov::op::v15::SearchSorted>(std::shared_ptr<ov
558558
extern template bool evaluate_node<ov::op::v16::Identity>(std::shared_ptr<ov::Node> node,
559559
ov::TensorVector& outputs,
560560
const ov::TensorVector& inputs);
561+
562+
extern template bool evaluate_node<ov::op::v16::SegmentMax>(std::shared_ptr<ov::Node> node,
563+
ov::TensorVector& outputs,
564+
const ov::TensorVector& inputs);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/reference/segment_max.hpp"
6+
7+
#include "element_visitor.hpp"
8+
#include "evaluate_node.hpp"
9+
#include "segment_max_shape_inference.hpp"
10+
11+
template <ov::element::Type_t ET_data, ov::element::Type_t ET_idx>
12+
bool evaluate_index_type(const std::shared_ptr<ov::op::v16::SegmentMax>& op,
13+
ov::TensorVector& outputs,
14+
const ov::TensorVector& inputs) {
15+
using T_data = typename ov::element_type_traits<ET_data>::value_type;
16+
using T_idx = typename ov::element_type_traits<ET_idx>::value_type;
17+
auto input_shapes = std::vector<ov::PartialShape>{op->get_input_shape(0), op->get_input_shape(1)};
18+
if (op->inputs().size() == 3) {
19+
input_shapes.emplace_back(op->get_input_shape(2));
20+
}
21+
const auto output_shape =
22+
ov::op::v16::shape_infer(op.get(), input_shapes, make_tensor_accessor(inputs)).front().to_shape();
23+
outputs.front().set_shape(output_shape);
24+
const auto empty_segment_value =
25+
op->get_fill_mode() == ov::op::FillMode::ZERO ? T_data(0) : std::numeric_limits<T_data>::lowest();
26+
ov::reference::segment_max(inputs[0].data<const T_data>(),
27+
inputs[0].get_shape(),
28+
inputs[1].data<const T_idx>(),
29+
outputs[0].data<T_data>(),
30+
outputs[0].get_shape(),
31+
empty_segment_value);
32+
return true;
33+
}
34+
35+
template <ov::element::Type_t ET_data>
36+
bool evaluate_data_type(const std::shared_ptr<ov::op::v16::SegmentMax>& op,
37+
ov::TensorVector& outputs,
38+
const ov::TensorVector& inputs) {
39+
const auto& index_type = op->get_input_element_type(1);
40+
using ov::op::v16::SegmentMax;
41+
using namespace ov::element;
42+
switch (index_type) {
43+
case i32:
44+
return evaluate_index_type<ET_data, i32>(ov::as_type_ptr<SegmentMax>(op), outputs, inputs);
45+
case i64:
46+
return evaluate_index_type<ET_data, i64>(ov::as_type_ptr<SegmentMax>(op), outputs, inputs);
47+
default:
48+
OPENVINO_THROW("Unhandled index type ", index_type, " in evaluate_node()");
49+
}
50+
}
51+
52+
template <>
53+
bool evaluate_node<ov::op::v16::SegmentMax>(std::shared_ptr<ov::Node> node,
54+
ov::TensorVector& outputs,
55+
const ov::TensorVector& inputs) {
56+
const auto& element_type = node->get_output_element_type(0);
57+
58+
using ov::op::v16::SegmentMax;
59+
using namespace ov::element;
60+
switch (element_type) {
61+
case i8:
62+
return evaluate_data_type<i8>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
63+
case i32:
64+
return evaluate_data_type<i32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
65+
case i64:
66+
return evaluate_data_type<i64>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
67+
case u8:
68+
return evaluate_data_type<u8>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
69+
case u32:
70+
return evaluate_data_type<u32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
71+
case u64:
72+
return evaluate_data_type<u64>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
73+
case f16:
74+
return evaluate_data_type<f16>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
75+
case f32:
76+
return evaluate_data_type<f32>(ov::as_type_ptr<SegmentMax>(node), outputs, inputs);
77+
default:
78+
OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()");
79+
}
80+
}

src/plugins/template/backend/opset_int_tbl.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ _OPENVINO_OP_REG(SearchSorted, ov::op::v15)
177177

178178
_OPENVINO_OP_REG(Identity, ov::op::v16)
179179
_OPENVINO_OP_REG(ISTFT, ov::op::v16)
180+
_OPENVINO_OP_REG(SegmentMax, ov::op::v16)
180181

181182
_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
182183
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)

0 commit comments

Comments
 (0)