Skip to content

Commit 7c195d3

Browse files
authored
Fix CUDA prelude for makeMatrix (shader-slang#5509)
* Fix CUDA prelude for makeMatrix * Add regression test.
1 parent 65de545 commit 7c195d3

File tree

2 files changed

+40
-26
lines changed

2 files changed

+40
-26
lines changed

prelude/slang-cuda-prelude.h

+26-26
Original file line numberDiff line numberDiff line change
@@ -733,12 +733,12 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(
733733
Matrix<T, ROWS, COLS> rs;
734734
if (COLS == 3)
735735
{
736-
rs.rows[0].x = v0;
737-
rs.rows[0].y = v1;
738-
rs.rows[0].z = v2;
739-
rs.rows[1].x = v3;
740-
rs.rows[1].y = v4;
741-
rs.rows[1].z = v5;
736+
*_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
737+
*_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
738+
*_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
739+
*_slang_vector_get_element_ptr(&rs.rows[1], 0) = v3;
740+
*_slang_vector_get_element_ptr(&rs.rows[1], 1) = v4;
741+
*_slang_vector_get_element_ptr(&rs.rows[1], 2) = v5;
742742
}
743743
else
744744
{
@@ -766,14 +766,14 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(
766766
Matrix<T, ROWS, COLS> rs;
767767
if (COLS == 4)
768768
{
769-
rs.rows[0].x = v0;
770-
rs.rows[0].y = v1;
771-
rs.rows[0].z = v2;
772-
rs.rows[0].w = v3;
773-
rs.rows[1].x = v4;
774-
rs.rows[1].y = v5;
775-
rs.rows[1].z = v6;
776-
rs.rows[1].w = v7;
769+
*_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
770+
*_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
771+
*_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
772+
*_slang_vector_get_element_ptr(&rs.rows[0], 3) = v3;
773+
*_slang_vector_get_element_ptr(&rs.rows[1], 0) = v4;
774+
*_slang_vector_get_element_ptr(&rs.rows[1], 1) = v5;
775+
*_slang_vector_get_element_ptr(&rs.rows[1], 2) = v6;
776+
*_slang_vector_get_element_ptr(&rs.rows[1], 3) = v7;
777777
}
778778
else
779779
{
@@ -832,18 +832,18 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, ROWS, COLS> makeMatrix(
832832
Matrix<T, ROWS, COLS> rs;
833833
if (COLS == 4)
834834
{
835-
rs.rows[0].x = v0;
836-
rs.rows[0].y = v1;
837-
rs.rows[0].z = v2;
838-
rs.rows[0].w = v3;
839-
rs.rows[1].x = v4;
840-
rs.rows[1].y = v5;
841-
rs.rows[1].z = v6;
842-
rs.rows[1].w = v7;
843-
rs.rows[2].x = v8;
844-
rs.rows[2].y = v9;
845-
rs.rows[2].z = v10;
846-
rs.rows[2].w = v11;
835+
*_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
836+
*_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
837+
*_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
838+
*_slang_vector_get_element_ptr(&rs.rows[0], 3) = v3;
839+
*_slang_vector_get_element_ptr(&rs.rows[1], 0) = v4;
840+
*_slang_vector_get_element_ptr(&rs.rows[1], 1) = v5;
841+
*_slang_vector_get_element_ptr(&rs.rows[1], 2) = v6;
842+
*_slang_vector_get_element_ptr(&rs.rows[1], 3) = v7;
843+
*_slang_vector_get_element_ptr(&rs.rows[2], 0) = v8;
844+
*_slang_vector_get_element_ptr(&rs.rows[2], 1) = v9;
845+
*_slang_vector_get_element_ptr(&rs.rows[2], 2) = v10;
846+
*_slang_vector_get_element_ptr(&rs.rows[2], 3) = v11;
847847
}
848848
else
849849
{

tests/cuda/make-matrix.slang

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -compute
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cuda -compute
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<uint4x3> outputBuffer : register(u0);
6+
7+
[numthreads(1, 1, 1)]
8+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
9+
{
10+
uint idx = dispatchThreadID.x + 1;
11+
uint4x3 mat1 = uint4x3(idx, idx, idx, idx, idx, idx, idx, idx, idx, idx, idx, idx);
12+
outputBuffer[0] = mat1;
13+
// CHECK: 1
14+
}

0 commit comments

Comments
 (0)