Skip to content

Commit 661d619

Browse files
authored
Support per field matrix layout (shader-slang#3101)
* Support per field matrix layout * Fix warnings. * Fix. * Fix tests. * Fix spiv gen. * Fix. * More test fixes. * Fix. * Run only GPU tests on self-hosted servers. * Remove -use-glsl-matrix-layout-modifier. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 0403e05 commit 661d619

File tree

76 files changed

+1370
-1355
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1370
-1355
lines changed

.github/workflows/windows-selfhosted.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ jobs:
4040
run: |
4141
$slangTestBinDir = ".\bin\windows-${{matrix.testPlatform}}\${{matrix.configuration}}\";
4242
$env:Path += ";$slangTestBinDir";
43-
& "$slangTestBinDir\slang-test.exe" -appveyor -bindir "$slangTestBinDir\" -platform ${{matrix.testPlatform}} -configuration ${{matrix.configuration}} -category ${{matrix.testCategory}} 2>&1;
43+
& "$slangTestBinDir\slang-test.exe" -appveyor -bindir "$slangTestBinDir\" -platform ${{matrix.testPlatform}} -configuration ${{matrix.configuration}} -category ${{matrix.testCategory}} -api all-cpu 2>&1;
4444

build/visual-studio/slang/slang.vcxproj

+4
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
406406
<ClInclude Include="..\..\..\source\slang\slang-ir-loop-unroll.h" />
407407
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-binding-query.h" />
408408
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h" />
409+
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-buffer-element-type.h" />
409410
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-com-methods.h" />
410411
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-error-handling.h" />
411412
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h" />
@@ -442,6 +443,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
442443
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-dispatch.h" />
443444
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-dynamic-associatedtype-lookup.h" />
444445
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-function-call.h" />
446+
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-matrix-layout.h" />
445447
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-resources.h" />
446448
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize.h" />
447449
<ClInclude Include="..\..\..\source\slang\slang-ir-spirv-legalize.h" />
@@ -612,6 +614,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
612614
<ClCompile Include="..\..\..\source\slang\slang-ir-loop-unroll.cpp" />
613615
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-binding-query.cpp" />
614616
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp" />
617+
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-buffer-element-type.cpp" />
615618
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-com-methods.cpp" />
616619
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-error-handling.cpp" />
617620
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp" />
@@ -648,6 +651,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
648651
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-dispatch.cpp" />
649652
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-dynamic-associatedtype-lookup.cpp" />
650653
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-function-call.cpp" />
654+
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-matrix-layout.cpp" />
651655
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-resources.cpp" />
652656
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize.cpp" />
653657
<ClCompile Include="..\..\..\source\slang\slang-ir-spirv-legalize.cpp" />

build/visual-studio/slang/slang.vcxproj.filters

+12
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,9 @@
306306
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h">
307307
<Filter>Header Files</Filter>
308308
</ClInclude>
309+
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-buffer-element-type.h">
310+
<Filter>Header Files</Filter>
311+
</ClInclude>
309312
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-com-methods.h">
310313
<Filter>Header Files</Filter>
311314
</ClInclude>
@@ -414,6 +417,9 @@
414417
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-function-call.h">
415418
<Filter>Header Files</Filter>
416419
</ClInclude>
420+
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-matrix-layout.h">
421+
<Filter>Header Files</Filter>
422+
</ClInclude>
417423
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-resources.h">
418424
<Filter>Header Files</Filter>
419425
</ClInclude>
@@ -920,6 +926,9 @@
920926
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp">
921927
<Filter>Source Files</Filter>
922928
</ClCompile>
929+
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-buffer-element-type.cpp">
930+
<Filter>Source Files</Filter>
931+
</ClCompile>
923932
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-com-methods.cpp">
924933
<Filter>Source Files</Filter>
925934
</ClCompile>
@@ -1028,6 +1037,9 @@
10281037
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-function-call.cpp">
10291038
<Filter>Source Files</Filter>
10301039
</ClCompile>
1040+
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-matrix-layout.cpp">
1041+
<Filter>Source Files</Filter>
1042+
</ClCompile>
10311043
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-resources.cpp">
10321044
<Filter>Source Files</Filter>
10331045
</ClCompile>

slang.h

-1
Original file line numberDiff line numberDiff line change
@@ -4131,7 +4131,6 @@ namespace slang
41314131

41324132
virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) = 0;
41334133

4134-
41354134
};
41364135

41374136
#define SLANG_UUID_ICompileRequest ICompileRequest::getTypeGuid()

source/slang-glslang/slang-glslang.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ static spv_target_env _getUniversalTargetEnv(glslang::EShTargetLanguageVersion i
509509
return SPV_ENV_UNIVERSAL_1_2;
510510
}
511511

512-
static int glslang_compileGLSLToSPIRV(const glslang_CompileRequest_1_2& request)
512+
static int glslang_compileGLSLToSPIRV(glslang_CompileRequest_1_2 request)
513513
{
514514
// Check that the encoding matches
515515
assert(glslang::EShTargetSpv_1_4 == _makeTargetLanguageVersion(1, 4));

source/slang/core.meta.slang

+20-17
Original file line numberDiff line numberDiff line change
@@ -894,8 +894,11 @@ struct vector
894894
__init(vector<T,N> value);
895895
}
896896

897+
const int kRowMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_ROW_MAJOR);
898+
const int kColumnMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_COLUMN_MAJOR);
899+
897900
/// A matrix with `R` rows and `C` columns, with elements of type `T`.
898-
__generic<T = float, let R : int = 4, let C : int = 4>
901+
__generic<T = float, let R : int = 4, let C : int = 4, let L : int = $(SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)>
899902
__magic_type(MatrixExpressionType)
900903
struct matrix
901904
{
@@ -1111,7 +1114,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
11111114
for( int R = 2; R <= 4; ++R )
11121115
for( int C = 2; C <= 4; ++C )
11131116
{
1114-
sb << "__generic<T> __extension matrix<T, " << R << "," << C << ">\n{\n";
1117+
sb << "__generic<T, let L:int> __extension matrix<T, " << R << "," << C << ", L>\n{\n";
11151118

11161119
// initialize from R*C scalars
11171120
sb << "__intrinsic_op(" << int(kIROp_MakeMatrix) << ") __init(";
@@ -1137,7 +1140,7 @@ for( int C = 2; C <= 4; ++C )
11371140
for( int cc = C; cc <= 4; ++cc )
11381141
{
11391142
if(rr == R && cc == C) continue;
1140-
sb << "__intrinsic_op(" << int(kIROp_MatrixReshape) << ") __init(matrix<T," << rr << "," << cc << "> value);\n";
1143+
sb << "__intrinsic_op(" << int(kIROp_MatrixReshape) << ") __init(matrix<T," << rr << "," << cc << ", L> value);\n";
11411144
}
11421145
sb << "}\n";
11431146
}
@@ -1207,16 +1210,16 @@ extension vector<T, N> : IDifferentiable
12071210
}
12081211
}
12091212

1210-
__generic<T:__BuiltinFloatingPointType, let R: int, let C: int>
1211-
extension matrix<T, R, C> : IDifferentiable
1213+
__generic<T:__BuiltinFloatingPointType, let R: int, let C: int, let L : int>
1214+
extension matrix<T, R, C, L> : IDifferentiable
12121215
{
1213-
typedef matrix<T, R, C> Differential;
1216+
typedef matrix<T, R, C, L> Differential;
12141217

12151218
[__unsafeForceInlineEarly]
12161219
[BackwardDifferentiable]
12171220
static Differential dzero()
12181221
{
1219-
return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero()));
1222+
return matrix<T, R, C, L>(__slang_noop_cast<T>(T.dzero()));
12201223
}
12211224

12221225
[__unsafeForceInlineEarly]
@@ -2425,9 +2428,9 @@ vector<T,N> operator$(op.name)(in out vector<T,N> value)
24252428
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
24262429

24272430
$(fixity.qual)
2428-
__generic<T : __BuiltinArithmeticType, let R : int, let C : int>
2431+
__generic<T : __BuiltinArithmeticType, let R : int, let C : int, let L : int>
24292432
[__unsafeForceInlineEarly]
2430-
matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C> value)
2433+
matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value)
24312434
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
24322435

24332436
$(fixity.qual)
@@ -2609,9 +2612,9 @@ __generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M :
26092612
__intrinsic_op($(info.op))
26102613
matrix<L,N,M> operator$(info.name)(matrix<L,N,M> left, matrix<R,N,M> right);
26112614

2612-
__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int>
2615+
__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int, let Layout : int>
26132616
[__unsafeForceInlineEarly]
2614-
matrix<L, N, M> operator$(info.name)=(in out matrix<L, N, M> left, matrix<R, N, M> right)
2617+
matrix<L, N, M> operator$(info.name)=(in out matrix<L, N, M, Layout> left, matrix<R, N, M> right)
26152618
{
26162619
left = left $(info.name) right;
26172620
return left;
@@ -2641,9 +2644,9 @@ __generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M :
26412644
__intrinsic_op($(info.op))
26422645
matrix<L,N,M> operator$(info.name)(matrix<L,N,M> left, R right);
26432646

2644-
__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int>
2647+
__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int, let Layout : int>
26452648
[__unsafeForceInlineEarly]
2646-
matrix<L,N,M> operator$(info.name)=(in out matrix<L,N,M> left, R right)
2649+
matrix<L,N,M> operator$(info.name)=(in out matrix<L,N,M, Layout> left, R right)
26472650
{
26482651
left = left $(info.name) right;
26492652
return left;
@@ -2696,17 +2699,17 @@ ${{{{
26962699
return left;
26972700
}
26982701

2699-
__generic<T : $(op.interface), let R : int, let C : int>
2702+
__generic<T : $(op.interface), let R : int, let C : int, let Layout : int>
27002703
[__unsafeForceInlineEarly]
2701-
matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C> left, matrix<T,R,C> right)
2704+
matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C,Layout> left, matrix<T,R,C> right)
27022705
{
27032706
left = left $(op.name) right;
27042707
return left;
27052708
}
27062709

2707-
__generic<T : $(op.interface), let R : int, let C : int>
2710+
__generic<T : $(op.interface), let R : int, let C : int, let Layout : int>
27082711
[__unsafeForceInlineEarly]
2709-
matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C> left, T right)
2712+
matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C, Layout> left, T right)
27102713
{
27112714
left = left $(op.name) right;
27122715
return left;

source/slang/diff.meta.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -1403,11 +1403,11 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c)
14031403
c = cos(x);
14041404
}
14051405

1406-
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
1406+
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L1 : int, let L2 : int>
14071407
[BackwardDifferentiable]
14081408
[PrimalSubstituteOf(sincos)]
14091409
[PreferRecompute]
1410-
void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c)
1410+
void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M, L1> s, out matrix<T, N, M, L2> c)
14111411
{
14121412
s = sin(x);
14131413
c = cos(x);

source/slang/hlsl.meta.slang

+21-17
Original file line numberDiff line numberDiff line change
@@ -235,19 +235,20 @@ struct StructuredBuffer
235235
out uint numStructs,
236236
out uint stride);
237237

238-
__target_intrinsic(glsl, "$0._data[$1]")
239-
__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;")
240-
[__readNone]
238+
//__target_intrinsic(glsl, "$0._data[$1]")
239+
//__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;")
240+
__intrinsic_op($(kIROp_StructuredBufferLoad))
241241
T Load(int location);
242242

243-
[__readNone]
243+
__intrinsic_op($(kIROp_StructuredBufferLoadStatus))
244244
T Load(int location, out uint status);
245245

246246
__subscript(uint index) -> T
247247
{
248-
__target_intrinsic(glsl, "$0._data[$1]")
249-
__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;")
248+
//__target_intrinsic(glsl, "$0._data[$1]")
249+
//__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;")
250250
[__readNone]
251+
__intrinsic_op($(kIROp_StructuredBufferLoad))
251252
get;
252253
};
253254
};
@@ -709,7 +710,6 @@ static const struct {
709710
for(auto item : kMutableStructuredBufferCases) {
710711
}}}}
711712

712-
713713
__generic<T>
714714
__magic_type(HLSL$(item.name)Type)
715715
__intrinsic_type($(item.op))
@@ -724,18 +724,22 @@ struct $(item.name)
724724

725725
uint IncrementCounter();
726726

727-
__target_intrinsic(glsl, "$0._data[$1]")
728-
__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;")
727+
//__target_intrinsic(glsl, "$0._data[$1]")
728+
//__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;")
729729
[__NoSideEffect]
730+
__intrinsic_op($(kIROp_RWStructuredBufferLoad))
730731
T Load(int location);
732+
731733
[__NoSideEffect]
734+
__intrinsic_op($(kIROp_RWStructuredBufferLoadStatus))
732735
T Load(int location, out uint status);
733736

734737
__subscript(uint index) -> T
735738
{
736-
__target_intrinsic(glsl, "$0._data[$1]")
737-
__target_intrinsic(spirv_direct, "*StorageBuffer OpAccessChain resultType resultId _0 const(int, 0) _1")
739+
//__target_intrinsic(glsl, "$0._data[$1]")
740+
//__target_intrinsic(spirv_direct, "*StorageBuffer OpAccessChain resultType resultId _0 const(int, 0) _1")
738741
[__NoSideEffect]
742+
__intrinsic_op($(kIROp_RWStructuredBufferGetElementPtr))
739743
ref;
740744
}
741745
};
@@ -2248,10 +2252,10 @@ vector<T, N> frexp(vector<T, N> x, out vector<T, N> exp)
22482252
VECTOR_MAP_BINARY(T, N, frexp, x, exp);
22492253
}
22502254

2251-
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
2255+
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L : int>
22522256
__target_intrinsic(hlsl)
22532257
[__readNone]
2254-
matrix<T, N, M> frexp(matrix<T, N, M> x, out matrix<T, N, M> exp)
2258+
matrix<T, N, M> frexp(matrix<T, N, M> x, out matrix<T, N, M, L> exp)
22552259
{
22562260
MATRIX_MAP_BINARY(T, N, M, frexp, x, exp);
22572261
}
@@ -2924,10 +2928,10 @@ vector<T,N> modf(vector<T,N> x, out vector<T,N> ip)
29242928
VECTOR_MAP_BINARY(T, N, modf, x, ip);
29252929
}
29262930

2927-
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
2931+
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L : int>
29282932
__target_intrinsic(hlsl)
29292933
[__readNone]
2930-
matrix<T,N,M> modf(matrix<T,N,M> x, out matrix<T,N,M> ip)
2934+
matrix<T,N,M> modf(matrix<T,N,M> x, out matrix<T,N,M,L> ip)
29312935
{
29322936
MATRIX_MAP_BINARY(T, N, M, modf, x, ip);
29332937
}
@@ -3664,10 +3668,10 @@ void sincos(vector<T,N> x, out vector<T,N> s, out vector<T,N> c)
36643668
c = cos(x);
36653669
}
36663670

3667-
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
3671+
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L1: int, let L2 : int>
36683672
__target_intrinsic(hlsl)
36693673
[__readNone]
3670-
void sincos(matrix<T,N,M> x, out matrix<T,N,M> s, out matrix<T,N,M> c)
3674+
void sincos(matrix<T,N,M> x, out matrix<T,N,M,L1> s, out matrix<T,N,M,L2> c)
36713675
{
36723676
s = sin(x);
36733677
c = cos(x);

source/slang/slang-ast-builder.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,21 @@ VectorExpressionType* ASTBuilder::getVectorType(
383383
return as<VectorExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType"));
384384
}
385385

386+
MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout)
387+
{
388+
// Canonicalize constant size arguments to int.
389+
if (auto rowCountConstantInt = as<ConstantIntVal>(rowCount))
390+
{
391+
rowCount = getIntVal(getIntType(), rowCountConstantInt->getValue());
392+
}
393+
if (auto colCountConstantInt = as<ConstantIntVal>(colCount))
394+
{
395+
colCount = getIntVal(getIntType(), colCountConstantInt->getValue());
396+
}
397+
Val* args[] = { elementType, rowCount, colCount, layout };
398+
return as<MatrixExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "MatrixExpressionType"));
399+
}
400+
386401
DifferentialPairType* ASTBuilder::getDifferentialPairType(
387402
Type* valueType,
388403
Witness* primalIsDifferentialWitness)

source/slang/slang-ast-builder.h

+2
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,8 @@ class ASTBuilder : public RefObject
424424

425425
VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount);
426426

427+
MatrixExpressionType* getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout);
428+
427429
ConstantBufferType* getConstantBufferType(Type* elementType);
428430

429431
ParameterBlockType* getParameterBlockType(Type* elementType);

source/slang/slang-ast-support-types.h

+3
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ namespace Slang
8080
// No conversion at all
8181
kConversionCost_None = 0,
8282

83+
// Convert between matrices of different layout
84+
kConversionCost_MatrixLayout = 5,
85+
8386
// Conversion from a buffer to the type it carries needs to add a minimal
8487
// extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>`
8588
// from one on `Foo`

source/slang/slang-ast-type.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,11 @@ IntVal* MatrixExpressionType::getColumnCount()
238238
return as<IntVal>(_getGenericTypeArg(this, 2));
239239
}
240240

241+
IntVal* MatrixExpressionType::getLayout()
242+
{
243+
return as<IntVal>(_getGenericTypeArg(this, 3));
244+
}
245+
241246
Type* MatrixExpressionType::getRowType()
242247
{
243248
if (!rowType)

source/slang/slang-ast-type.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -435,14 +435,15 @@ class VectorExpressionType : public ArithmeticExpressionType
435435
IntVal* getElementCount();
436436
};
437437

438-
// A matrix type, e.g., `matrix<T,R,C>`
438+
// A matrix type, e.g., `matrix<T,R,C,L>`
439439
class MatrixExpressionType : public ArithmeticExpressionType
440440
{
441441
SLANG_AST_CLASS(MatrixExpressionType)
442442

443443
Type* getElementType();
444444
IntVal* getRowCount();
445445
IntVal* getColumnCount();
446+
IntVal* getLayout();
446447

447448
Type* getRowType();
448449

0 commit comments

Comments
 (0)