|
6 | 6 | import pytest
|
7 | 7 |
|
8 | 8 | from openvino import PartialShape, Type
|
| 9 | +from openvino.runtime import opset4, opset15 |
9 | 10 |
|
10 |
| -import openvino.runtime.opset13 as ov |
| 11 | +scatter_version_opset = pytest.mark.parametrize("opset", [opset4, opset15]) |
11 | 12 |
|
12 | 13 |
|
13 |
| -def test_scatter_nd_update(): |
| 14 | +@scatter_version_opset |
| 15 | +def test_scatter_nd_update(opset): |
14 | 16 | data_shape = [4, 4, 4]
|
15 | 17 | indices_shape = [2, 1]
|
16 | 18 | updates_shape = [2, 4, 4]
|
17 | 19 |
|
18 |
| - data_param = ov.parameter(shape=data_shape, dtype=Type.f32, name="data") |
19 |
| - indices_param = ov.parameter(shape=indices_shape, dtype=Type.i32, name="indices") |
20 |
| - updates_param = ov.parameter(shape=updates_shape, dtype=Type.f32, name="updates") |
| 20 | + data_param = opset.parameter(shape=data_shape, dtype=Type.f32, name="data") |
| 21 | + indices_param = opset.parameter(shape=indices_shape, dtype=Type.i32, name="indices") |
| 22 | + updates_param = opset.parameter(shape=updates_shape, dtype=Type.f32, name="updates") |
21 | 23 |
|
22 |
| - scatter_nd_node = ov.scatter_nd_update(data_param, indices_param, updates_param) |
| 24 | + scatter_nd_node = opset.scatter_nd_update(data_param, indices_param, updates_param) |
23 | 25 |
|
24 | 26 | assert scatter_nd_node.get_type_name() == "ScatterNDUpdate"
|
25 | 27 | assert scatter_nd_node.get_output_size() == 1
|
26 | 28 | assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data_shape))
|
27 | 29 | assert scatter_nd_node.get_output_element_type(0) == Type.f32
|
28 | 30 |
|
29 | 31 |
|
30 |
| -def test_scatter_nd_update_basic(): |
| 32 | +@scatter_version_opset |
| 33 | +def test_scatter_nd_update_basic(opset): |
31 | 34 | data = np.array([1, 2, 3, 4, 5])
|
32 | 35 | indices = np.array([[0], [2]])
|
33 | 36 | updates = np.array([9, 10])
|
34 | 37 |
|
35 |
| - result = ov.scatter_nd_update(data, indices, updates) |
36 |
| - expected = np.array([9, 2, 10, 4, 5]) |
37 |
| - np.testing.assert_array_equal(result, expected) |
| 38 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 39 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 40 | + assert scatter_nd_node.get_output_size() == 1 |
| 41 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 42 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
38 | 43 |
|
39 | 44 |
|
40 |
| -def test_scatter_nd_update_multidimensional(): |
| 45 | +@scatter_version_opset |
| 46 | +def test_scatter_nd_update_multidimensional(opset): |
41 | 47 | data = np.array([[1, 2], [3, 4]])
|
42 | 48 | indices = np.array([[0, 1], [1, 0]])
|
43 | 49 | updates = np.array([9, 10])
|
44 | 50 |
|
45 |
| - result = ov.scatter_nd_update(data, indices, updates) |
46 |
| - expected = np.array([[1, 9], [10, 4]]) |
47 |
| - np.testing.assert_array_equal(result, expected) |
| 51 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 52 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 53 | + assert scatter_nd_node.get_output_size() == 1 |
| 54 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 55 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
48 | 56 |
|
49 | 57 |
|
50 |
| -def test_scatter_nd_update_mismatched_updates_shape(): |
| 58 | +@scatter_version_opset |
| 59 | +def test_scatter_nd_update_mismatched_updates_shape(opset): |
51 | 60 | data = np.array([1, 2, 3])
|
52 | 61 | indices = np.array([[0], [1]])
|
53 | 62 | updates = np.array([4])
|
54 | 63 |
|
55 | 64 | with pytest.raises(RuntimeError):
|
56 |
| - ov.scatter_nd_update(data, indices, updates) |
| 65 | + opset.scatter_nd_update(data, indices, updates) |
57 | 66 |
|
58 | 67 |
|
59 |
| -def test_scatter_nd_update_non_integer_indices(): |
| 68 | +@scatter_version_opset |
| 69 | +def test_scatter_nd_update_non_integer_indices(opset): |
60 | 70 | data = np.array([1, 2, 3])
|
61 | 71 | indices = np.array([[0.5]])
|
62 | 72 | updates = np.array([4])
|
63 | 73 |
|
64 | 74 | with pytest.raises(RuntimeError):
|
65 |
| - ov.scatter_nd_update(data, indices, updates) |
| 75 | + opset.scatter_nd_update(data, indices, updates) |
66 | 76 |
|
67 | 77 |
|
68 |
| -def test_scatter_nd_update_negative_indices(): |
| 78 | +@scatter_version_opset |
| 79 | +def test_scatter_nd_update_negative_indices(opset): |
69 | 80 | data = np.array([1, 2, 3, 4])
|
70 | 81 | indices = np.array([[-1]])
|
71 | 82 | updates = np.array([5])
|
72 | 83 |
|
73 |
| - result = ov.scatter_nd_update(data, indices, updates) |
74 |
| - expected = np.array([1, 2, 3, 5]) |
75 |
| - np.testing.assert_array_equal(result, expected) |
| 84 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 85 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 86 | + assert scatter_nd_node.get_output_size() == 1 |
| 87 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 88 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
76 | 89 |
|
77 | 90 |
|
78 |
| -def test_scatter_nd_update_multi_index_per_update(): |
| 91 | +@scatter_version_opset |
| 92 | +def test_scatter_nd_update_multi_index_per_update(opset): |
79 | 93 | data = np.array([[1, 2], [3, 4]])
|
80 | 94 | indices = np.array([[0, 0], [0, 1]])
|
81 | 95 | updates = np.array([5, 6])
|
82 | 96 |
|
83 |
| - result = ov.scatter_nd_update(data, indices, updates) |
84 |
| - expected = np.array([[5, 6], [3, 4]]) |
85 |
| - np.testing.assert_array_equal(result, expected) |
| 97 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 98 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 99 | + assert scatter_nd_node.get_output_size() == 1 |
| 100 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 101 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
86 | 102 |
|
87 | 103 |
|
88 |
| -def test_scatter_nd_update_non_contiguous_indices(): |
| 104 | +@scatter_version_opset |
| 105 | +def test_scatter_nd_update_non_contiguous_indices(opset): |
89 | 106 | data = np.array([10, 20, 30, 40, 50])
|
90 | 107 | indices = np.array([[0], [3]])
|
91 | 108 | updates = np.array([100, 400])
|
92 | 109 |
|
93 |
| - result = ov.scatter_nd_update(data, indices, updates) |
94 |
| - expected = np.array([100, 20, 30, 400, 50]) |
95 |
| - np.testing.assert_array_equal(result, expected) |
| 110 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 111 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 112 | + assert scatter_nd_node.get_output_size() == 1 |
| 113 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 114 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
96 | 115 |
|
97 | 116 |
|
98 |
| -def test_scatter_nd_update_large_updates(): |
| 117 | +@scatter_version_opset |
| 118 | +def test_scatter_nd_update_large_updates(opset): |
99 | 119 | data = np.zeros(1000, dtype=np.float64)
|
100 | 120 | indices = np.reshape(np.arange(1000), (-1, 1))
|
101 | 121 | updates = np.arange(1000, dtype=np.float64)
|
102 | 122 |
|
103 |
| - result = ov.scatter_nd_update(data, indices, updates) |
104 |
| - expected = np.arange(1000, dtype=np.float64) |
105 |
| - np.testing.assert_array_equal(result, expected) |
| 123 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 124 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 125 | + assert scatter_nd_node.get_output_size() == 1 |
| 126 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 127 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
106 | 128 |
|
107 | 129 |
|
108 |
| -def test_scatter_nd_update_overlapping_indices(): |
| 130 | +@scatter_version_opset |
| 131 | +def test_scatter_nd_update_opseterlapping_indices(opset): |
109 | 132 | data = np.array([1, 2, 3, 4, 5])
|
110 | 133 | indices = np.array([[1], [1], [3]])
|
111 | 134 | updates = np.array([10, 20, 30])
|
112 | 135 |
|
113 |
| - result = ov.scatter_nd_update(data, indices, updates) |
114 |
| - expected = np.array([1, 20, 3, 30, 5]) |
115 |
| - np.testing.assert_array_equal(result, expected) |
| 136 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 137 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 138 | + assert scatter_nd_node.get_output_size() == 1 |
| 139 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 140 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
116 | 141 |
|
117 | 142 |
|
118 |
| -def test_scatter_nd_update_3d_data(): |
| 143 | +@scatter_version_opset |
| 144 | +def test_scatter_nd_update_3d_data(opset): |
119 | 145 | data = np.zeros((2, 2, 2), dtype=np.float64)
|
120 | 146 | indices = np.array([[0, 0, 1], [1, 1, 0]])
|
121 | 147 | updates = np.array([1, 2], dtype=np.float64)
|
122 | 148 |
|
123 |
| - result = ov.scatter_nd_update(data, indices, updates) |
124 |
| - expected = np.array([[[0, 1], [0, 0]], [[0, 0], [2, 0]]], dtype=np.float64) |
125 |
| - np.testing.assert_array_equal(result, expected) |
| 149 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 150 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 151 | + assert scatter_nd_node.get_output_size() == 1 |
| 152 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 153 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
126 | 154 |
|
127 | 155 |
|
128 |
| -def test_scatter_nd_update_all_indices(): |
| 156 | +@scatter_version_opset |
| 157 | +def test_scatter_nd_update_all_indices(opset): |
129 | 158 | data = np.ones((2, 3), dtype=np.float64)
|
130 | 159 | indices = np.array([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]])
|
131 | 160 | updates = np.array([10, 20, 30, 40, 50, 60], dtype=np.float64)
|
132 | 161 |
|
133 |
| - result = ov.scatter_nd_update(data, indices, updates) |
134 |
| - expected = np.array([[10, 20, 30], [40, 50, 60]], dtype=np.float64) |
135 |
| - np.testing.assert_array_equal(result, expected) |
| 162 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 163 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 164 | + assert scatter_nd_node.get_output_size() == 1 |
| 165 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 166 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
136 | 167 |
|
137 | 168 |
|
138 |
| -def test_scatter_nd_update_invalid_updates_shape(): |
| 169 | +@scatter_version_opset |
| 170 | +def test_scatter_nd_update_invalid_updates_shape(opset): |
139 | 171 | data = np.array([1, 2, 3, 4])
|
140 | 172 | indices = np.array([[1], [2]])
|
141 | 173 | updates = np.array([5])
|
142 | 174 |
|
143 | 175 | with pytest.raises(RuntimeError):
|
144 |
| - ov.scatter_nd_update(data, indices, updates) |
| 176 | + opset.scatter_nd_update(data, indices, updates) |
145 | 177 |
|
146 | 178 |
|
147 |
| -def test_scatter_nd_update_negative_updates(): |
| 179 | +@scatter_version_opset |
| 180 | +def test_scatter_nd_update_negative_updates(opset): |
148 | 181 | data = np.array([1, 2, 3, 4, 5])
|
149 | 182 | indices = np.array([[1], [3]])
|
150 | 183 | updates = np.array([-1, -2])
|
151 | 184 |
|
152 |
| - result = ov.scatter_nd_update(data, indices, updates) |
153 |
| - expected = np.array([1, -1, 3, -2, 5]) |
154 |
| - np.testing.assert_array_equal(result, expected) |
| 185 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 186 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 187 | + assert scatter_nd_node.get_output_size() == 1 |
| 188 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 189 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
155 | 190 |
|
156 | 191 |
|
157 |
| -def test_scatter_nd_update_empty_indices_and_updates(): |
| 192 | +@scatter_version_opset |
| 193 | +def test_scatter_nd_update_empty_indices_and_updates(opset): |
158 | 194 | data = np.array([1, 2, 3], dtype=np.float64)
|
159 | 195 | indices = np.array([], dtype=np.int64).reshape(0, 1)
|
160 | 196 | updates = np.array([], dtype=np.float64)
|
161 | 197 |
|
162 |
| - result = ov.scatter_nd_update(data, indices, updates) |
163 |
| - expected = np.array([1, 2, 3], dtype=np.float64) |
164 |
| - np.testing.assert_array_equal(result, expected) |
| 198 | + scatter_nd_node = opset.scatter_nd_update(data, indices, updates) |
| 199 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 200 | + assert scatter_nd_node.get_output_size() == 1 |
| 201 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 202 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
| 203 | + |
| 204 | + |
| 205 | +@pytest.mark.parametrize( |
| 206 | + "reduction", |
| 207 | + [ |
| 208 | + None, |
| 209 | + "none", |
| 210 | + "sUm", |
| 211 | + "SUB", |
| 212 | + "pRod", |
| 213 | + "miN", |
| 214 | + "Max", |
| 215 | + ], |
| 216 | +) |
| 217 | +def test_scatter_nd_update_reduction(reduction): |
| 218 | + data = np.array([1, 2, 3, 4, 5]) |
| 219 | + indices = np.array([[0], [2]]) |
| 220 | + updates = np.array([9, 10]) |
| 221 | + |
| 222 | + scatter_nd_node = opset15.scatter_nd_update(data, indices, updates, reduction) |
| 223 | + assert scatter_nd_node.get_type_name() == "ScatterNDUpdate" |
| 224 | + op_attrs = scatter_nd_node.get_attributes() |
| 225 | + if reduction is None: |
| 226 | + assert op_attrs["reduction"] == "none" |
| 227 | + else: |
| 228 | + assert op_attrs["reduction"] == reduction.lower() |
| 229 | + assert scatter_nd_node.get_output_size() == 1 |
| 230 | + assert scatter_nd_node.get_output_partial_shape(0).same_scheme(PartialShape(data.shape)) |
| 231 | + assert scatter_nd_node.get_output_element_type(0) == Type(data.dtype) |
0 commit comments