Skip to content

Commit d941cf2

Browse files
Merge pull request #872 from IntelPython/adding-stack-feature
Added dpctl.tensor.stack feature and tests
2 parents 880d71e + 7e21c25 commit d941cf2

File tree

4 files changed

+176
-15
lines changed

4 files changed

+176
-15
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3232
* Wrote manual page about working with `dpctl.SyclQueue` [#829](https://github.com/IntelPython/dpctl/pull/829).
3333
* Added cmake scripts to dpctl package layout and a way to query the location [#853](https://github.com/IntelPython/dpctl/pull/853).
3434
* Implemented `dpctl.tensor.concat` function from array-API [#867](https://github.com/IntelPython/dpctl/867).
35+
* Implemented `dpctl.tensor.stack` function from array-API [#872](https://github.com/IntelPython/dpctl/872).
3536

3637

3738
### Changed

dpctl/tensor/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
permute_dims,
4646
roll,
4747
squeeze,
48+
stack,
4849
)
4950
from dpctl.tensor._reshape import reshape
5051
from dpctl.tensor._usmarray import usm_ndarray
@@ -68,6 +69,7 @@
6869
"reshape",
6970
"roll",
7071
"concat",
72+
"stack",
7173
"broadcast_arrays",
7274
"broadcast_to",
7375
"expand_dims",

dpctl/tensor/_manipulation_functions.py

+63-15
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,7 @@ def roll(X, shift, axes=None):
288288
return res
289289

290290

291-
def concat(arrays, axis=0):
292-
"""
293-
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
294-
295-
Joins a sequence of arrays along an existing axis.
296-
"""
291+
def _arrays_validation(arrays):
297292
n = len(arrays)
298293
if n == 0:
299294
raise TypeError("Missing 1 required positional argument: 'arrays'")
@@ -324,11 +319,23 @@ def concat(arrays, axis=0):
324319
for i in range(1, n):
325320
if X0.ndim != arrays[i].ndim:
326321
raise ValueError(
327-
"All the input arrays must have same number of "
328-
"dimensions, but the array at index 0 has "
329-
f"{X0.ndim} dimension(s) and the array at index "
330-
f"{i} has {arrays[i].ndim} dimension(s)"
322+
"All the input arrays must have same number of dimensions, "
323+
f"but the array at index 0 has {X0.ndim} dimension(s) and the "
324+
f"array at index {i} has {arrays[i].ndim} dimension(s)"
331325
)
326+
return res_dtype, res_usm_type, exec_q
327+
328+
329+
def concat(arrays, axis=0):
330+
"""
331+
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
332+
333+
Joins a sequence of arrays along an existing axis.
334+
"""
335+
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
336+
337+
n = len(arrays)
338+
X0 = arrays[0]
332339

333340
axis = normalize_axis_index(axis, X0.ndim)
334341
X0_shape = X0.shape
@@ -337,11 +344,10 @@ def concat(arrays, axis=0):
337344
for j in range(X0.ndim):
338345
if X0_shape[j] != Xi_shape[j] and j != axis:
339346
raise ValueError(
340-
"All the input array dimensions for the "
341-
"concatenation axis must match exactly, but "
342-
f"along dimension {j}, the array at index 0 "
343-
f"has size {X0_shape[j]} and the array at "
344-
f"index {i} has size {Xi_shape[j]}"
347+
"All the input array dimensions for the concatenation "
348+
f"axis must match exactly, but along dimension {j}, the "
349+
f"array at index 0 has size {X0_shape[j]} and the array "
350+
f"at index {i} has size {Xi_shape[j]}"
345351
)
346352

347353
res_shape_axis = 0
@@ -373,3 +379,45 @@ def concat(arrays, axis=0):
373379
dpctl.SyclEvent.wait_for(hev_list)
374380

375381
return res
382+
383+
384+
def stack(arrays, axis=0):
385+
"""
386+
stack(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
387+
388+
Joins a sequence of arrays along a new axis.
389+
"""
390+
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
391+
392+
n = len(arrays)
393+
X0 = arrays[0]
394+
res_ndim = X0.ndim + 1
395+
axis = normalize_axis_index(axis, res_ndim)
396+
X0_shape = X0.shape
397+
398+
for i in range(1, n):
399+
if X0_shape != arrays[i].shape:
400+
raise ValueError("All input arrays must have the same shape")
401+
402+
res_shape = tuple(
403+
X0_shape[i - 1 * (i >= axis)] if i != axis else n
404+
for i in range(res_ndim)
405+
)
406+
407+
res = dpt.empty(
408+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
409+
)
410+
411+
hev_list = []
412+
for i in range(n):
413+
c_shapes_copy = tuple(
414+
i if j == axis else np.s_[:] for j in range(res_ndim)
415+
)
416+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
417+
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
418+
)
419+
hev_list.append(hev)
420+
421+
dpctl.SyclEvent.wait_for(hev_list)
422+
423+
return res

dpctl/tests/test_usm_ndarray_manipulation.py

+110
Original file line numberDiff line numberDiff line change
@@ -890,3 +890,113 @@ def test_concat_3arrays(data):
890890
R = dpt.concat([X, Y, Z], axis=axis)
891891

892892
assert_array_equal(Rnp, dpt.asnumpy(R))
893+
894+
895+
def test_stack_incorrect_shape():
896+
try:
897+
q = dpctl.SyclQueue()
898+
except dpctl.SyclQueueCreationError:
899+
pytest.skip("Queue could not be created")
900+
901+
X = dpt.ones((1,), sycl_queue=q)
902+
Y = dpt.ones((2,), sycl_queue=q)
903+
904+
pytest.raises(ValueError, dpt.stack, [X, Y], 0)
905+
906+
907+
@pytest.mark.parametrize(
908+
"data",
909+
[
910+
[(6,), 0],
911+
[(2, 3), 1],
912+
[(3, 2), -1],
913+
[(1, 6), 2],
914+
[(2, 1, 3), 2],
915+
],
916+
)
917+
def test_stack_1array(data):
918+
try:
919+
q = dpctl.SyclQueue()
920+
except dpctl.SyclQueueCreationError:
921+
pytest.skip("Queue could not be created")
922+
923+
shape, axis = data
924+
925+
Xnp = np.arange(6).reshape(shape)
926+
X = dpt.asarray(Xnp, sycl_queue=q)
927+
928+
Ynp = np.stack([Xnp], axis=axis)
929+
Y = dpt.stack([X], axis=axis)
930+
931+
assert_array_equal(Ynp, dpt.asnumpy(Y))
932+
933+
Ynp = np.stack((Xnp,), axis=axis)
934+
Y = dpt.stack((X,), axis=axis)
935+
936+
assert_array_equal(Ynp, dpt.asnumpy(Y))
937+
938+
939+
@pytest.mark.parametrize(
940+
"data",
941+
[
942+
[(1,), 0],
943+
[(0, 2), 0],
944+
[(2, 0), 0],
945+
[(2, 3), 0],
946+
[(2, 3), 1],
947+
[(2, 3), 2],
948+
[(2, 3), -1],
949+
[(2, 3), -2],
950+
[(2, 2, 2), 1],
951+
],
952+
)
953+
def test_stack_2arrays(data):
954+
try:
955+
q = dpctl.SyclQueue()
956+
except dpctl.SyclQueueCreationError:
957+
pytest.skip("Queue could not be created")
958+
959+
shape, axis = data
960+
961+
Xnp = np.ones(shape)
962+
X = dpt.asarray(Xnp, sycl_queue=q)
963+
964+
Ynp = np.zeros(shape)
965+
Y = dpt.asarray(Ynp, sycl_queue=q)
966+
967+
Znp = np.stack([Xnp, Ynp], axis=axis)
968+
print(Znp.shape)
969+
Z = dpt.stack([X, Y], axis=axis)
970+
971+
assert_array_equal(Znp, dpt.asnumpy(Z))
972+
973+
974+
@pytest.mark.parametrize(
975+
"data",
976+
[
977+
[(1,), 0],
978+
[(0, 2), 0],
979+
[(2, 1, 2), 1],
980+
],
981+
)
982+
def test_stack_3arrays(data):
983+
try:
984+
q = dpctl.SyclQueue()
985+
except dpctl.SyclQueueCreationError:
986+
pytest.skip("Queue could not be created")
987+
988+
shape, axis = data
989+
990+
Xnp = np.ones(shape)
991+
X = dpt.asarray(Xnp, sycl_queue=q)
992+
993+
Ynp = np.zeros(shape)
994+
Y = dpt.asarray(Ynp, sycl_queue=q)
995+
996+
Znp = np.full(shape, 2.0)
997+
Z = dpt.asarray(Znp, sycl_queue=q)
998+
999+
Rnp = np.stack([Xnp, Ynp, Znp], axis=axis)
1000+
R = dpt.stack([X, Y, Z], axis=axis)
1001+
1002+
assert_array_equal(Rnp, dpt.asnumpy(R))

0 commit comments

Comments
 (0)