Skip to content

Commit 731f1fc

Browse files
jsmall-zzzTim Foley
and
Tim Foley
authored
CUDA half comparison support (shader-slang#1834)
* #include an absolute path didn't work - because paths were taken to always be relative. * Split out StringEscapeUtil. * Added StringEscapeUtil. * Fix typo in unix quoting type. * Small comment improvements. * Try to fix linux linking issue. * Fix typo. * Attempt to fix linux link issue. * Update VS proj even though nothing really changed. * Fix another typo issue. * Fix for windows issue. Fixed bug. * Make separate Utils for escaping. * Fix typo. * Split out into StringEscapeHandler. * Windows shell does handle removing quotes (so remove code to remove them). * Handle unescaping if not initiating using the shell. * Slight improvement around shell like decoding. * Simplify command extraction. * Add shared-library category type. * Fix bug in command extraction. * Typo in transcendental category. * Enable unit-test on in smoke test category. * Make parsing failing output as a failing test. * Fixes for transcendental tests. Disable tests that do not work. * Changed category parsing. * Removed the TestResult parameter from _gatherTestsForFile. Made testsList only output. * Remove testing if all tests were disabled. * Make args of CommandLine always unescaped. * Add category. * Don't need escaping on unix/linux. * Remove some no longer used functions. * Add requireSMVersion to CUDAExtensionTracker. * half-calc.slang now works for CUDA. * bit-cast-16-bit works on CUDA. * WIP handling of CUDA vector<half> types. * Half swizzle CUDA. * Half vector test. * Fix swizzle half bug. * Fix compilation issue with narrowing to Index. * Add unary ops. * Add some vector scalar maths ops. * Add half vector conversions for CUDA. * Fix erroneous comment. * Support for half comparisons. * First pass test for half compare. * Fix bug in CUDA specialized emit control. Updated tests to have pre and post inc/dec. * Removed unneeded parts of the cuda prelude. * Half structured buffer works on CUDA. Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
1 parent dc571f1 commit 731f1fc

12 files changed

+279
-70
lines changed

prelude/slang-cuda-prelude.h

+43-45
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
// are passed down.
66

77
#ifdef SLANG_CUDA_ENABLE_HALF
8+
// We don't want the half2 operators that cuda_fp16.h provides, because its comparison operators return a single bool(!).
9+
// We want to generate those functions ourselves. Disabling the built-in set means we have to define all the other half2 operators too.
10+
# define __CUDA_NO_HALF2_OPERATORS__
811
# include <cuda_fp16.h>
912
#endif
1013

@@ -155,6 +158,7 @@ union Union64
155158
// Three-component half vector: x and y share one __half2, z is held separately.
struct __half3
{
    __half2 xy;
    __half z;
};
156159
// Four-component half vector stored as two __half2 pairs: (x, y) and (z, w).
struct __half4
{
    __half2 xy;
    __half2 zw;
};
157160

161+
// *** convert ***
158162

159163
// half -> other
160164

@@ -196,7 +200,43 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 convert___half2(const double2& v) { r
196200
// Converts a double3 to a __half3.
// Each component is first converted to float explicitly: brace-initializing a float2
// straight from double values is a narrowing conversion, which list-initialization
// does not allow. The float values are then converted with the `_rn` intrinsics.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 convert___half3(const double3& v)
{
    return __half3{ __float22half2_rn(float2{float(v.x), float(v.y)}), __float2half_rn(float(v.z)) };
}
197201
// Converts a double4 to a __half4.
// Each component is first converted to float explicitly: brace-initializing a float2
// straight from double values is a narrowing conversion, which list-initialization
// does not allow. Each float pair is then converted with __float22half2_rn.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 convert___half4(const double4& v)
{
    return __half4{ __float22half2_rn(float2{float(v.x), float(v.y)}), __float22half2_rn(float2{float(v.z), float(v.w)}) };
}
198202

199-
// half2
203+
// *** make ***
204+
205+
// Mechanism to make half vectors
206+
// Builds a __half2 from two scalar halves, passed to __halves2half2 in (x, y) order.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y)
{
    const __half2 packed = __halves2half2(x, y);
    return packed;
}
207+
// Builds a __half3 from three scalar halves: x and y are packed into .xy, z is stored as-is.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z)
{
    const __half2 xyPair = __halves2half2(x, y);
    return __half3{ xyPair, z };
}
208+
// Builds a __half4 from four scalar halves: (x, y) are packed into .xy and (z, w) into .zw.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w)
{
    const __half2 xyPair = __halves2half2(x, y);
    const __half2 zwPair = __halves2half2(z, w);
    return __half4{ xyPair, zwPair };
}
209+
210+
// *** constructFromScalar ***
211+
212+
// Builds a __half2 with the scalar x replicated into both components.
// The parameter is spelled `__half`, matching the make___half* helpers and the
// __half3/__half4 members, rather than relying on the `half` alias.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 constructFromScalar___half2(__half x)
{
    return __half2half2(x);
}
213+
// Builds a __half3 with the scalar x replicated into all three components.
// The parameter is spelled `__half`, matching the make___half* helpers and the
// __half3/__half4 members, rather than relying on the `half` alias.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 constructFromScalar___half3(__half x)
{
    return __half3{ __half2half2(x), x };
}
214+
// Builds a __half4 with the scalar x replicated into all four components.
// The parameter is spelled `__half`, matching the make___half* helpers and the
// __half3/__half4 members, rather than relying on the `half` alias.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 constructFromScalar___half4(__half x)
{
    // Replicate once and reuse the pair for both halves of the vector.
    const __half2 v = __half2half2(x);
    return __half4{ v, v };
}
215+
216+
// *** half2 ***
217+
218+
// half2 maths ops
219+
220+
// NOTE! By default these operators are defined in cuda_fp16.hpp, but we disable them because we need to define
221+
// our own comparison operators: we need versions that return vector<bool> rather than a single bool.
222+
223+
// half2 + half2, forwarded to __hadd2.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, const __half2& rh)
{
    const __half2 sum = __hadd2(lh, rh);
    return sum;
}
224+
// half2 - half2, forwarded to __hsub2.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& lh, const __half2& rh)
{
    const __half2 difference = __hsub2(lh, rh);
    return difference;
}
225+
// half2 * half2, forwarded to __hmul2.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(const __half2& lh, const __half2& rh)
{
    const __half2 product = __hmul2(lh, rh);
    return product;
}
226+
// half2 / half2, forwarded to __h2div.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(const __half2& lh, const __half2& rh)
{
    const __half2 quotient = __h2div(lh, rh);
    return quotient;
}
227+
228+
// In-place half2 addition; returns a reference to the updated left operand.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator+=(__half2& lh, const __half2& rh)
{
    const __half2 updated = __hadd2(lh, rh);
    lh = updated;
    return lh;
}
229+
// In-place half2 subtraction; returns a reference to the updated left operand.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator-=(__half2& lh, const __half2& rh)
{
    const __half2 updated = __hsub2(lh, rh);
    lh = updated;
    return lh;
}
230+
// In-place half2 multiplication; returns a reference to the updated left operand.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator*=(__half2& lh, const __half2& rh)
{
    const __half2 updated = __hmul2(lh, rh);
    lh = updated;
    return lh;
}
231+
// In-place half2 division; returns a reference to the updated left operand.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator/=(__half2& lh, const __half2& rh)
{
    const __half2 updated = __h2div(lh, rh);
    lh = updated;
    return lh;
}
232+
233+
// Pre-increment: adds one to both components in place and returns the updated vector.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator++(__half2 &h)
{
    // 0x3C00 is the binary16 bit pattern for 1.0.
    __half2_raw unit;
    unit.x = 0x3C00;
    unit.y = 0x3C00;
    h = __hadd2(h, unit);
    return h;
}
234+
// Pre-decrement: subtracts one from both components in place and returns the updated vector.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator--(__half2 &h)
{
    // 0x3C00 is the binary16 bit pattern for 1.0.
    __half2_raw unit;
    unit.x = 0x3C00;
    unit.y = 0x3C00;
    h = __hsub2(h, unit);
    return h;
}
235+
// Post-increment: adds one to both components in place and returns the value held before the update.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator++(__half2 &h, int)
{
    __half2 previous = h;
    // 0x3C00 is the binary16 bit pattern for 1.0.
    __half2_raw unit;
    unit.x = 0x3C00;
    unit.y = 0x3C00;
    h = __hadd2(h, unit);
    return previous;
}
236+
// Post-decrement: subtracts one from both components in place and returns the value held before the update.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator--(__half2 &h, int)
{
    __half2 previous = h;
    // 0x3C00 is the binary16 bit pattern for 1.0.
    __half2_raw unit;
    unit.x = 0x3C00;
    unit.y = 0x3C00;
    h = __hsub2(h, unit);
    return previous;
}
237+
238+
// Unary plus: returns a copy of the operand unchanged.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2 &h)
{
    return h;
}
239+
// Unary minus, forwarded to __hneg2.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2 &h)
{
    const __half2 negated = __hneg2(h);
    return negated;
}
200240

201241
// vec op scalar
202242
// half2 + scalar: the scalar is replicated into a half2 and then added.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, __half rh)
{
    const __half2 splat = __half2half2(rh);
    return __hadd2(lh, splat);
}
@@ -210,16 +250,7 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(__half lh, const __half2& r
210250
// scalar * half2: the scalar is replicated into a half2 and used as the left operand.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(__half lh, const __half2& rh)
{
    const __half2 splat = __half2half2(lh);
    return __hmul2(splat, rh);
}
211251
// scalar / half2: the scalar is replicated into a half2 and used as the dividend.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(__half lh, const __half2& rh)
{
    const __half2 splat = __half2half2(lh);
    return __h2div(splat, rh);
}
212252

213-
// Mechanism to make half vectors
214-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y) { return __halves2half2(x, y); }
215-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z) { return __half3{ __halves2half2(x, y), z }; }
216-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w) { return __half4{ __halves2half2(x, y), __halves2half2(z, w)}; }
217-
218-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 constructFromScalar___half2(half x) { return __half2half2(x); }
219-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 constructFromScalar___half3(half x) { return __half3{__half2half2(x), x}; }
220-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 constructFromScalar___half4(half x) { const __half2 v = __half2half2(x); return __half4{v, v}; }
221-
222-
// Half3 maths ops
253+
// *** half3 ***
223254

224255
// half3 + half3: the .xy pair uses the half2 intrinsic, .z uses the scalar one.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator+(const __half3& lh, const __half3& rh)
{
    const __half2 xySum = __hadd2(lh.xy, rh.xy);
    const __half zSum = __hadd(lh.z, rh.z);
    return __half3{ xySum, zSum };
}
225256
// half3 - half3: the .xy pair uses the half2 intrinsic, .z uses the scalar one.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(const __half3& lh, const __half3& rh)
{
    const __half2 xyDiff = __hsub2(lh.xy, rh.xy);
    const __half zDiff = __hsub(lh.z, rh.z);
    return __half3{ xyDiff, zDiff };
}
@@ -241,18 +272,7 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(__half lh, const __half3& r
241272
// scalar * half3: the scalar is replicated for the .xy pair and used directly for .z.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator*(__half lh, const __half3& rh)
{
    const __half2 xyProduct = __hmul2(__half2half2(lh), rh.xy);
    const __half zProduct = __hmul(lh, rh.z);
    return __half3{ xyProduct, zProduct };
}
242273
// scalar / half3: the scalar is replicated for the .xy pair and used directly for .z.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator/(__half lh, const __half3& rh)
{
    const __half2 xyQuotient = __h2div(__half2half2(lh), rh.xy);
    const __half zQuotient = __hdiv(lh, rh.z);
    return __half3{ xyQuotient, zQuotient };
}
243274

244-
245-
#if 0
246-
// We need to return the vector<bool> type
247-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator==(const __half3& lh, const __half3& rh) { return __hbeq2(lh.xy, rh.xy) && __heq(lh.z, rh.z); }
248-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator!=(const __half3& lh, const __half3& rh) { return __hbneu2(lh.xy, rh.xy) && __hneu(lh.z, rh.z); }
249-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>(const __half3& lh, const __half3& rh) { return __hbgt2(lh.xy, rh.xy) && __hgt(lh.z, rh.z); }
250-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<(const __half3& lh, const __half3& rh) { return __hblt2(lh.xy, rh.xy) && __hlt(lh.z, rh.z); }
251-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>=(const __half3& lh, const __half3& rh) { return __hbge2(lh.xy, rh.xy) && __hge(lh.z, rh.z); }
252-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<=(const __half3& lh, const __half3& rh) { return __hble2(lh.xy, rh.xy) && __hle(lh.z, rh.z); }
253-
#endif
254-
255-
// Half4 maths ops
275+
// *** half4 ***
256276

257277
// half4 + half4: the .xy and .zw pairs are each added with the half2 intrinsic.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(const __half4& lh, const __half4& rh)
{
    const __half2 xySum = __hadd2(lh.xy, rh.xy);
    const __half2 zwSum = __hadd2(lh.zw, rh.zw);
    return __half4{ xySum, zwSum };
}
258278
// half4 - half4: the .xy and .zw pairs are each subtracted with the half2 intrinsic.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& lh, const __half4& rh)
{
    const __half2 xyDiff = __hsub2(lh.xy, rh.xy);
    const __half2 zwDiff = __hsub2(lh.zw, rh.zw);
    return __half4{ xyDiff, zwDiff };
}
@@ -274,28 +294,6 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator/(__half lh, const __half4& r
274294
// Unary minus for half4: negates the .xy and .zw pairs with __hneg2.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& h)
{
    const __half2 xyNegated = __hneg2(h.xy);
    const __half2 zwNegated = __hneg2(h.zw);
    return __half4{ xyNegated, zwNegated };
}
275295
// Unary plus for half4: returns a copy of the operand unchanged.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(const __half4& h)
{
    return h;
}
276296

277-
#if 0
278-
// We need to return vector<bool> type
279-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator==(const __half4& lh, const __half4& rh) { return __hbeq2(lh.xy, rh.xy) && __hbeq2(lh.zw, rh.zw); }
280-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator!=(const __half4& lh, const __half4& rh) { return __hbneu2(lh.xy, rh.xy) && __hbneu2(lh.zw, rh.zw); }
281-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>(const __half4& lh, const __half4& rh) { return __hbgt2(lh.xy, rh.xy) && __hbgt2(lh.zw, rh.zw); }
282-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<(const __half4& lh, const __half4& rh) { return __hblt2(lh.xy, rh.xy) && __hblt2(lh.zw, rh.zw); }
283-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>=(const __half4& lh, const __half4& rh) { return __hbge2(lh.xy, rh.xy) && __hbge2(lh.zw, rh.zw); }
284-
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<=(const __half4& lh, const __half4& rh) { return __hble2(lh.xy, rh.xy) && __hble2(lh.zw, rh.zw); }
285-
#endif
286-
287-
// Use the round nearest as the default - it is the only one defined
288-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float22half2(const float2 a) { return __float22half2_rn(a); }
289-
290-
// Implement the vector versions
291-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float2half(float2 a) { return __float22half2(a); }
292-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 __float2half(float3 a) { __half3 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.z = __float2half(a.z); return o; }
293-
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 __float2half(float4 a) { __half4 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.zw = __float22half2(make_float2(a.z, a.w)); return o; }
294-
295-
SLANG_FORCE_INLINE SLANG_CUDA_CALL float2 __half2float(__half2 a) { return __half22float2(a); }
296-
SLANG_FORCE_INLINE SLANG_CUDA_CALL float3 __half2float(__half3 a) { float2 xy = __half22float2(a.xy); float z = __half2float(a.z); return make_float3(xy.x, xy.y, z); }
297-
SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 __half2float(__half4 a) { float2 xy = __half22float2(a.xy); float2 zw = __half22float2(a.zw); return make_float4(xy.x, xy.y, zw.x, zw.y); }
298-
299297
#endif
300298

301299
// ----------------------------- F32 -----------------------------------------

source/slang/slang-emit-cpp.cpp

+84-16
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ When called we can have a structure that holds the thread local variables, and t
6161

6262
namespace Slang {
6363

64-
static const char s_elemNames[] = "xyzw";
64+
// Single-character axis names, indexed by axis number (0 -> 'x' ... 3 -> 'w').
// Used below to name the per-axis loop variables and members when emitting thread-group loops.
static const char s_xyzwNames[] = "xyzw";
6565

6666
static UnownedStringSlice _getTypePrefix(IROp op)
6767
{
@@ -219,6 +219,9 @@ void CPPSourceEmitter::emitTypeDefinition(IRType* inType)
219219
case kIROp_VectorType:
220220
{
221221
auto vecType = static_cast<IRVectorType*>(type);
222+
223+
const UnownedStringSlice* elemNames = getVectorElementNames(vecType);
224+
222225
int count = int(getIntVal(vecType->getElementCount()));
223226

224227
SLANG_ASSERT(count > 0 && count < 4);
@@ -239,7 +242,7 @@ void CPPSourceEmitter::emitTypeDefinition(IRType* inType)
239242
{
240243
writer->emit(", ");
241244
}
242-
writer->emitChar(s_elemNames[i]);
245+
writer->emit(elemNames[i]);
243246
}
244247
writer->emit(";\n");
245248

@@ -648,22 +651,43 @@ static IRBasicType* _getElementType(IRType* type)
648651
case kIROp_VectorType:
649652
{
650653
auto vecType = static_cast<IRVectorType*>(type);
654+
655+
IRBasicType* elemBasicType = as<IRBasicType>(vecType->getElementType());
656+
const BaseType baseType = elemBasicType->getBaseType();
657+
651658
const int elemCount = int(getIntVal(vecType->getElementCount()));
652-
return (!vecSwap) ? TypeDimension{1, elemCount} : TypeDimension{ elemCount, 1};
659+
return (!vecSwap) ? TypeDimension{baseType, 1, elemCount} : TypeDimension{ baseType, elemCount, 1};
653660
}
654661
case kIROp_MatrixType:
655662
{
656663
auto matType = static_cast<IRMatrixType*>(type);
657664
const int colCount = int(getIntVal(matType->getColumnCount()));
658665
const int rowCount = int(getIntVal(matType->getRowCount()));
659-
return TypeDimension{rowCount, colCount};
666+
667+
IRBasicType* elemBasicType = as<IRBasicType>(matType->getElementType());
668+
const BaseType baseType = elemBasicType->getBaseType();
669+
670+
return TypeDimension{baseType, rowCount, colCount};
671+
}
672+
default:
673+
{
674+
// Assume we don't know the type
675+
BaseType baseType = BaseType::Void;
676+
677+
IRBasicType* basicType = as<IRBasicType>(type);
678+
if (basicType)
679+
{
680+
baseType = basicType->getBaseType();
681+
}
682+
683+
return TypeDimension{baseType, 1, 1};
660684
}
661-
default: return TypeDimension{1, 1};
662685
}
663686
}
664687

665-
/* static */void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer)
688+
void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer)
666689
{
690+
667691
writer->emit(name);
668692
const int comb = (dimension.colCount > 1 ? 2 : 0) | (dimension.rowCount > 1 ? 1 : 0);
669693
switch (comb)
@@ -673,21 +697,32 @@ static IRBasicType* _getElementType(IRType* type)
673697
break;
674698
}
675699
case 1:
700+
{
701+
// Vector, row count is biggest
702+
const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.rowCount);
703+
writer->emit(".");
704+
const int index = (row > col) ? row : col;
705+
writer->emit(elemNames[index]);
706+
break;
707+
}
676708
case 2:
677709
{
678-
// Vector
679-
int index = (row > col) ? row : col;
710+
// Vector cols biggest dimension
711+
const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.colCount);
680712
writer->emit(".");
681-
writer->emitChar(s_elemNames[index]);
713+
const int index = (row > col) ? row : col;
714+
writer->emit(elemNames[index]);
682715
break;
683716
}
684717
case 3:
685-
{
718+
{
686719
// Matrix
720+
const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.colCount);
721+
687722
writer->emit(".rows[");
688723
writer->emit(row);
689724
writer->emit("].");
690-
writer->emitChar(s_elemNames[col]);
725+
writer->emit(elemNames[col]);
691726
break;
692727
}
693728
}
@@ -1158,9 +1193,11 @@ void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, c
11581193
{
11591194
Index paramElementCount = Index(getIntVal(paramVecType->getElementCount()));
11601195

1196+
const UnownedStringSlice* elemNames = getVectorElementNames(paramVecType);
1197+
11611198
writer->emitChar('a' + char(paramIndex));
11621199
writer->emit(".");
1163-
writer->emitChar(s_elemNames[paramSubIndex]);
1200+
writer->emit(elemNames[paramSubIndex]);
11641201

11651202
paramSubIndex ++;
11661203

@@ -1348,6 +1385,11 @@ void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const
13481385
auto swizzleInst = static_cast<IRSwizzle*>(inst);
13491386
const Index elementCount = Index(swizzleInst->getElementCount());
13501387

1388+
IRType* srcType = swizzleInst->getBase()->getDataType();
1389+
IRVectorType* srcVecType = as<IRVectorType>(srcType);
1390+
1391+
const UnownedStringSlice* elemNames = getVectorElementNames(srcVecType);
1392+
13511393
// TODO(JS): Not 100% sure this is correct on the parens handling front
13521394
IRType* retType = specOp->returnType;
13531395
emitType(retType);
@@ -1373,7 +1415,7 @@ void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const
13731415
UInt elementIndex = (UInt)irConst->value.intVal;
13741416
SLANG_RELEASE_ASSERT(elementIndex < 4);
13751417

1376-
writer->emitChar(s_elemNames[elementIndex]);
1418+
writer->emit(elemNames[elementIndex]);
13771419
}
13781420

13791421
writer->emit("}");
@@ -2119,6 +2161,32 @@ void CPPSourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* de
21192161
}
21202162
}
21212163

2164+
/// Returns the table of member names used to address the elements of a vector.
///
/// This implementation ignores both `baseType` and `elemCount` and always returns
/// the same four-entry table: "x", "y", "z", "w". The table has static storage
/// duration, so the returned pointer remains valid after the call.
const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(BaseType baseType, Index elemCount)
{
    // Neither argument affects the result here.
    SLANG_UNUSED(elemCount);
    SLANG_UNUSED(baseType);

    static const UnownedStringSlice s_componentNames[] =
    {
        UnownedStringSlice::fromLiteral("x"), UnownedStringSlice::fromLiteral("y"),
        UnownedStringSlice::fromLiteral("z"), UnownedStringSlice::fromLiteral("w"),
    };

    return &s_componentNames[0];
}
2179+
2180+
/// Convenience overload: looks up the element names for an IR vector type.
///
/// Reads the element count and the base type of the (canonical) element type from
/// `vectorType`, then forwards to getVectorElementNames(BaseType, Index).
///
/// `vectorType` must be non-null and its element type must be a basic type. Both
/// conditions are checked with SLANG_RELEASE_ASSERT, as is done elsewhere in this
/// file, because the pointers are dereferenced immediately afterwards.
/// NOTE(review): assumes SLANG_ASSERT may be compiled out in some build
/// configurations, which is why the stronger check is used here - confirm.
const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(IRVectorType* vectorType)
{
    // Callers may pass the result of as<IRVectorType>(...), which can be null.
    SLANG_RELEASE_ASSERT(vectorType);

    const Index elemCount = Index(getIntVal(vectorType->getElementCount()));

    IRType* type = vectorType->getElementType()->getCanonicalType();
    IRBasicType* basicType = as<IRBasicType>(type);
    SLANG_RELEASE_ASSERT(basicType);

    return getVectorElementNames(basicType->getBaseType(), elemCount);
}
2189+
21222190
bool CPPSourceEmitter::_tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec)
21232191
{
21242192
HLSLIntrinsic* specOp = m_intrinsicSet.add(inst);
@@ -2444,7 +2512,7 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup
24442512
{
24452513
const auto& axis = axes[i];
24462514
builder.Clear();
2447-
const char elem[2] = { s_elemNames[axis.axis], 0 };
2515+
const char elem[2] = { s_xyzwNames[axis.axis], 0 };
24482516
builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n";
24492517
m_writer->emit(builder);
24502518
m_writer->indent();
@@ -2478,7 +2546,7 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread
24782546
{
24792547
const auto& axis = axes[i];
24802548
builder.Clear();
2481-
const char elem[2] = { s_elemNames[axis.axis], 0 };
2549+
const char elem[2] = { s_xyzwNames[axis.axis], 0 };
24822550

24832551
builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n";
24842552
m_writer->emit(builder);
@@ -2511,7 +2579,7 @@ void CPPSourceEmitter::_emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupA
25112579
for (int i = 0; i < kThreadGroupAxisCount; ++i)
25122580
{
25132581
builder.Clear();
2514-
const char elem[2] = { s_elemNames[i], 0 };
2582+
const char elem[2] = { s_xyzwNames[i], 0 };
25152583
builder << mulName << "." << elem << " * " << sizeAlongAxis[i];
25162584
if (addName.getLength() > 0)
25172585
{

source/slang/slang-emit-cpp.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class CPPSourceEmitter: public CLikeSourceEmitter
3333
{
3434
bool isScalar() const { return rowCount <= 1 && colCount <= 1; }
3535

36+
// Base type of the elements. _getTypeDimension sets this to BaseType::Void when
// the type is not a vector, a matrix or a basic type.
BaseType elemType;
3637
int rowCount;
3738
int colCount;
3839
};
@@ -74,11 +75,15 @@ class CPPSourceEmitter: public CLikeSourceEmitter
7475

7576
virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE;
7677

78+
// Returns the table of member names used to access the elements of a vector with the
// given element type and element count. Virtual so the names can be replaced; the
// CPPSourceEmitter implementation returns "x", "y", "z", "w" regardless of the arguments.
virtual const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount);
79+
7780
// Replaceable for classes derived from CPPSourceEmitter
7881
virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out);
7982
virtual SlangResult calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& out);
8083
virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder);
8184

85+
// Convenience overload: extracts the element base type and element count from
// vectorType and forwards to the virtual getVectorElementNames(BaseType, Index).
const UnownedStringSlice* getVectorElementNames(IRVectorType* vectorType);
86+
8287
void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp);
8388

8489
void _emitForwardDeclarations(const List<EmitAction>& actions);
@@ -96,10 +101,12 @@ class CPPSourceEmitter: public CLikeSourceEmitter
96101

97102
void _emitInOutParamType(IRType* type, String const& name, IRType* valueType);
98103

104+
99105
UnownedStringSlice _getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType);
100106

101107
static TypeDimension _getTypeDimension(IRType* type, bool vecSwap);
102-
static void _emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer);
108+
109+
// Emits `name` followed by the accessor for element (row, col) of a value with the given dimension.
// A non-static member because the implementation calls the virtual getVectorElementNames.
void _emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer);
103110

104111
UnownedStringSlice _getScalarFuncName(HLSLIntrinsic::Op operation, IRBasicType* scalarType);
105112

0 commit comments

Comments
 (0)