Skip to content

Commit 3777479

Browse files
mmikolajczmitruska
andauthored
[PyAPI][OP]Add ScatterNDUpdate-15 to PyAPI (openvinotoolkit#24188)
### Details: - *Add ScatterNDUpdate-15 to PyAPI* - *Parametrize to test both ScatterND versions + add test for reduce attribute* ### Tickets: - *111092* --------- Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
1 parent 2d8ac08 commit 3777479

File tree

3 files changed

+154
-56
lines changed

3 files changed

+154
-56
lines changed

src/bindings/python/src/openvino/runtime/opset15/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
# Inlcudes new operators added in Opset15
66

77
# TODO (ticket 138273): Add previous opset operators at the end of opset15 development
8+
from openvino.runtime.opset1.ops import parameter
9+
from openvino.runtime.opset15.ops import scatter_nd_update

src/bindings/python/src/openvino/runtime/opset15/ops.py

+29
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,38 @@
44

55
"""Factory functions for ops added to openvino opset15."""
66
from functools import partial
7+
from typing import Optional, Literal
78

9+
from openvino.runtime import Node, Type
810
from openvino.runtime.opset_utils import _get_node_factory
11+
from openvino.runtime.utils.decorators import nameable_op
12+
from openvino.runtime.utils.types import NodeInput, as_nodes
913

1014
_get_node_factory_opset15 = partial(_get_node_factory, "opset15")
1115

1216
# -------------------------------------------- ops ------------------------------------------------
17+
18+
19+
@nameable_op
20+
def scatter_nd_update(
21+
data: NodeInput,
22+
indices: NodeInput,
23+
updates: NodeInput,
24+
reduction: Optional[Literal["none", "sum", "sub", "prod", "min", "max"]] = None,
25+
name: Optional[str] = None,
26+
) -> Node:
27+
"""Return a node which performs ScatterNDUpdate.
28+
29+
:param data: Node input representing the tensor to be updated.
30+
:param indices: Node input representing the indices at which updates will be applied.
31+
:param updates: Node input representing the updates to be applied.
32+
:param reduction: The type of operation to perform on the inputs. One of "none", "sum",
33+
"sub", "prod", "min", "max".
34+
:param name: Optional name for the output node.
35+
:return: New node performing the ScatterNDUpdate.
36+
"""
37+
inputs = as_nodes(data, indices, updates, name=name)
38+
attributes = {}
39+
if reduction:
40+
attributes["reduction"] = reduction
41+
return _get_node_factory_opset15().create("ScatterNDUpdate", inputs, attributes)

src/bindings/python/tests/test_graph/test_ops_scatter_nd_update.py

+123-56
Original file line numberDiff line numberDiff line change
@@ -6,159 +6,226 @@
66
import pytest
77

88
from openvino import PartialShape, Type
9+
from openvino.runtime import opset4, opset15
910

10-
import openvino.runtime.opset13 as ov
11+
scatter_version_opset = pytest.mark.parametrize("opset", [opset4, opset15])
1112

1213

13-
def test_scatter_nd_update():
14+
@scatter_version_opset
15+
def test_scatter_nd_update(opset):
1416
data_shape = [4, 4, 4]
1517
indices_shape = [2, 1]
1618
updates_shape = [2, 4, 4]
1719

18-
data_param = ov.parameter(shape=data_shape, dtype=Type.f32, name="data")
19-
indices_param = ov.parameter(shape=indices_shape, dtype=Type.i32, name="indices")
20-
updates_param = ov.parameter(shape=updates_shape, dtype=Type.f32, name="updates")
20+
data_param = opset.parameter(shape=data_shape, dtype=Type.f32, name="data")
21+
indices_param = opset.parameter(shape=indices_shape, dtype=Type.i32, name="indices")
22+
updates_param = opset.parameter(shape=updates_shape, dtype=Type.f32, name="updates")
2123

22-
scatter_nd_node = ov.scatter_nd_update(data_param, indices_param, updates_param)
24+
scatter_nd_node = opset.scatter_nd_update(data_param, indices_param, updates_param)
2325

2426
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
2527
assert scatter_nd_node.get_output_size() == 1
2628
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data_shape))
2729
assert scatter_nd_node.get_output_element_type(0) == Type.f32
2830

2931

30-
def test_scatter_nd_update_basic():
32+
@scatter_version_opset
33+
def test_scatter_nd_update_basic(opset):
3134
data = np.array([1, 2, 3, 4, 5])
3235
indices = np.array([[0], [2]])
3336
updates = np.array([9, 10])
3437

35-
result = ov.scatter_nd_update(data, indices, updates)
36-
expected = np.array([9, 2, 10, 4, 5])
37-
np.testing.assert_array_equal(result, expected)
38+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
39+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
40+
assert scatter_nd_node.get_output_size() == 1
41+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
42+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
3843

3944

40-
def test_scatter_nd_update_multidimensional():
45+
@scatter_version_opset
46+
def test_scatter_nd_update_multidimensional(opset):
4147
data = np.array([[1, 2], [3, 4]])
4248
indices = np.array([[0, 1], [1, 0]])
4349
updates = np.array([9, 10])
4450

45-
result = ov.scatter_nd_update(data, indices, updates)
46-
expected = np.array([[1, 9], [10, 4]])
47-
np.testing.assert_array_equal(result, expected)
51+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
52+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
53+
assert scatter_nd_node.get_output_size() == 1
54+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
55+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
4856

4957

50-
def test_scatter_nd_update_mismatched_updates_shape():
58+
@scatter_version_opset
59+
def test_scatter_nd_update_mismatched_updates_shape(opset):
5160
data = np.array([1, 2, 3])
5261
indices = np.array([[0], [1]])
5362
updates = np.array([4])
5463

5564
with pytest.raises(RuntimeError):
56-
ov.scatter_nd_update(data, indices, updates)
65+
opset.scatter_nd_update(data, indices, updates)
5766

5867

59-
def test_scatter_nd_update_non_integer_indices():
68+
@scatter_version_opset
69+
def test_scatter_nd_update_non_integer_indices(opset):
6070
data = np.array([1, 2, 3])
6171
indices = np.array([[0.5]])
6272
updates = np.array([4])
6373

6474
with pytest.raises(RuntimeError):
65-
ov.scatter_nd_update(data, indices, updates)
75+
opset.scatter_nd_update(data, indices, updates)
6676

6777

68-
def test_scatter_nd_update_negative_indices():
78+
@scatter_version_opset
79+
def test_scatter_nd_update_negative_indices(opset):
6980
data = np.array([1, 2, 3, 4])
7081
indices = np.array([[-1]])
7182
updates = np.array([5])
7283

73-
result = ov.scatter_nd_update(data, indices, updates)
74-
expected = np.array([1, 2, 3, 5])
75-
np.testing.assert_array_equal(result, expected)
84+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
85+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
86+
assert scatter_nd_node.get_output_size() == 1
87+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
88+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
7689

7790

78-
def test_scatter_nd_update_multi_index_per_update():
91+
@scatter_version_opset
92+
def test_scatter_nd_update_multi_index_per_update(opset):
7993
data = np.array([[1, 2], [3, 4]])
8094
indices = np.array([[0, 0], [0, 1]])
8195
updates = np.array([5, 6])
8296

83-
result = ov.scatter_nd_update(data, indices, updates)
84-
expected = np.array([[5, 6], [3, 4]])
85-
np.testing.assert_array_equal(result, expected)
97+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
98+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
99+
assert scatter_nd_node.get_output_size() == 1
100+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
101+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
86102

87103

88-
def test_scatter_nd_update_non_contiguous_indices():
104+
@scatter_version_opset
105+
def test_scatter_nd_update_non_contiguous_indices(opset):
89106
data = np.array([10, 20, 30, 40, 50])
90107
indices = np.array([[0], [3]])
91108
updates = np.array([100, 400])
92109

93-
result = ov.scatter_nd_update(data, indices, updates)
94-
expected = np.array([100, 20, 30, 400, 50])
95-
np.testing.assert_array_equal(result, expected)
110+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
111+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
112+
assert scatter_nd_node.get_output_size() == 1
113+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
114+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
96115

97116

98-
def test_scatter_nd_update_large_updates():
117+
@scatter_version_opset
118+
def test_scatter_nd_update_large_updates(opset):
99119
data = np.zeros(1000, dtype=np.float64)
100120
indices = np.reshape(np.arange(1000), (-1, 1))
101121
updates = np.arange(1000, dtype=np.float64)
102122

103-
result = ov.scatter_nd_update(data, indices, updates)
104-
expected = np.arange(1000, dtype=np.float64)
105-
np.testing.assert_array_equal(result, expected)
123+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
124+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
125+
assert scatter_nd_node.get_output_size() == 1
126+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
127+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
106128

107129

108-
def test_scatter_nd_update_overlapping_indices():
130+
@scatter_version_opset
131+
def test_scatter_nd_update_opseterlapping_indices(opset):
109132
data = np.array([1, 2, 3, 4, 5])
110133
indices = np.array([[1], [1], [3]])
111134
updates = np.array([10, 20, 30])
112135

113-
result = ov.scatter_nd_update(data, indices, updates)
114-
expected = np.array([1, 20, 3, 30, 5])
115-
np.testing.assert_array_equal(result, expected)
136+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
137+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
138+
assert scatter_nd_node.get_output_size() == 1
139+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
140+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
116141

117142

118-
def test_scatter_nd_update_3d_data():
143+
@scatter_version_opset
144+
def test_scatter_nd_update_3d_data(opset):
119145
data = np.zeros((2, 2, 2), dtype=np.float64)
120146
indices = np.array([[0, 0, 1], [1, 1, 0]])
121147
updates = np.array([1, 2], dtype=np.float64)
122148

123-
result = ov.scatter_nd_update(data, indices, updates)
124-
expected = np.array([[[0, 1], [0, 0]], [[0, 0], [2, 0]]], dtype=np.float64)
125-
np.testing.assert_array_equal(result, expected)
149+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
150+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
151+
assert scatter_nd_node.get_output_size() == 1
152+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
153+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
126154

127155

128-
def test_scatter_nd_update_all_indices():
156+
@scatter_version_opset
157+
def test_scatter_nd_update_all_indices(opset):
129158
data = np.ones((2, 3), dtype=np.float64)
130159
indices = np.array([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]])
131160
updates = np.array([10, 20, 30, 40, 50, 60], dtype=np.float64)
132161

133-
result = ov.scatter_nd_update(data, indices, updates)
134-
expected = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.float64)
135-
np.testing.assert_array_equal(result, expected)
162+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
163+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
164+
assert scatter_nd_node.get_output_size() == 1
165+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
166+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
136167

137168

138-
def test_scatter_nd_update_invalid_updates_shape():
169+
@scatter_version_opset
170+
def test_scatter_nd_update_invalid_updates_shape(opset):
139171
data = np.array([1, 2, 3, 4])
140172
indices = np.array([[1], [2]])
141173
updates = np.array([5])
142174

143175
with pytest.raises(RuntimeError):
144-
ov.scatter_nd_update(data, indices, updates)
176+
opset.scatter_nd_update(data, indices, updates)
145177

146178

147-
def test_scatter_nd_update_negative_updates():
179+
@scatter_version_opset
180+
def test_scatter_nd_update_negative_updates(opset):
148181
data = np.array([1, 2, 3, 4, 5])
149182
indices = np.array([[1], [3]])
150183
updates = np.array([-1, -2])
151184

152-
result = ov.scatter_nd_update(data, indices, updates)
153-
expected = np.array([1, -1, 3, -2, 5])
154-
np.testing.assert_array_equal(result, expected)
185+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
186+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
187+
assert scatter_nd_node.get_output_size() == 1
188+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
189+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
155190

156191

157-
def test_scatter_nd_update_empty_indices_and_updates():
192+
@scatter_version_opset
193+
def test_scatter_nd_update_empty_indices_and_updates(opset):
158194
data = np.array([1, 2, 3], dtype=np.float64)
159195
indices = np.array([], dtype=np.int64).reshape(0, 1)
160196
updates = np.array([], dtype=np.float64)
161197

162-
result = ov.scatter_nd_update(data, indices, updates)
163-
expected = np.array([1, 2, 3], dtype=np.float64)
164-
np.testing.assert_array_equal(result, expected)
198+
scatter_nd_node = opset.scatter_nd_update(data, indices, updates)
199+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
200+
assert scatter_nd_node.get_output_size() == 1
201+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
202+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)
203+
204+
205+
@pytest.mark.parametrize(
206+
"reduction",
207+
[
208+
None,
209+
"none",
210+
"sUm",
211+
"SUB",
212+
"pRod",
213+
"miN",
214+
"Max",
215+
],
216+
)
217+
def test_scatter_nd_update_reduction(reduction):
218+
data = np.array([1, 2, 3, 4, 5])
219+
indices = np.array([[0], [2]])
220+
updates = np.array([9, 10])
221+
222+
scatter_nd_node = opset15.scatter_nd_update(data, indices, updates, reduction)
223+
assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
224+
op_attrs = scatter_nd_node.get_attributes()
225+
if reduction is None:
226+
assert op_attrs["reduction"] == "none"
227+
else:
228+
assert op_attrs["reduction"] == reduction.lower()
229+
assert scatter_nd_node.get_output_size() == 1
230+
assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape))
231+
assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype)

0 commit comments

Comments
 (0)