Skip to content

Commit 02bb741

Browse files
authored
Preserve type cast during AST constant folding. (shader-slang#2912)
* Preserve type cast during AST constant folding. Fixes shader-slang#2891. * Fix. * Fix truncating. * fix test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 5dd401e commit 02bb741

9 files changed

+190
-23
lines changed

source/slang/slang-ast-val.cpp

+102
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,108 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
12171217
return this;
12181218
}
12191219

1220+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeCastIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
1221+
bool TypeCastIntVal::_equalsValOverride(Val* val)
1222+
{
1223+
if (auto typeCastIntVal = as<TypeCastIntVal>(val))
1224+
{
1225+
if (!type->equals(typeCastIntVal->type))
1226+
return false;
1227+
if (!base->equalsVal(typeCastIntVal->base))
1228+
return false;
1229+
return true;
1230+
}
1231+
return false;
1232+
}
1233+
1234+
void TypeCastIntVal::_toTextOverride(StringBuilder& out)
1235+
{
1236+
type->toText(out);
1237+
out << "(";
1238+
base->toText(out);
1239+
out << ")";
1240+
}
1241+
1242+
HashCode TypeCastIntVal::_getHashCodeOverride()
1243+
{
1244+
HashCode result = type->getHashCode();
1245+
result = combineHash(result, base->getHashCode());
1246+
return result;
1247+
}
1248+
1249+
Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink)
1250+
{
1251+
SLANG_UNUSED(sink);
1252+
1253+
if (auto c = as<ConstantIntVal>(base))
1254+
{
1255+
IntegerLiteralValue resultValue = c->value;
1256+
auto baseType = as<BasicExpressionType>(resultType);
1257+
if (baseType)
1258+
{
1259+
switch (baseType->baseType)
1260+
{
1261+
case BaseType::Int:
1262+
resultValue = (int)resultValue;
1263+
break;
1264+
case BaseType::UInt:
1265+
resultValue = (unsigned int)resultValue;
1266+
break;
1267+
case BaseType::Int64:
1268+
case BaseType::IntPtr:
1269+
resultValue = (Int64)resultValue;
1270+
break;
1271+
case BaseType::UInt64:
1272+
case BaseType::UIntPtr:
1273+
resultValue = (UInt64)resultValue;
1274+
break;
1275+
case BaseType::Int16:
1276+
resultValue = (int16_t)resultValue;
1277+
break;
1278+
case BaseType::UInt16:
1279+
resultValue = (uint16_t)resultValue;
1280+
break;
1281+
case BaseType::Int8:
1282+
resultValue = (int8_t)resultValue;
1283+
break;
1284+
case BaseType::UInt8:
1285+
resultValue = (uint8_t)resultValue;
1286+
break;
1287+
default:
1288+
return nullptr;
1289+
}
1290+
}
1291+
return astBuilder->getIntVal(resultType, resultValue);
1292+
}
1293+
return nullptr;
1294+
}
1295+
1296+
Val* TypeCastIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
1297+
{
1298+
int diff = 0;
1299+
auto substBase = base->substituteImpl(astBuilder, subst, &diff);
1300+
if (substBase != base)
1301+
diff++;
1302+
auto substType = as<Type>(type->substituteImpl(astBuilder, subst, &diff));
1303+
if (substType != type)
1304+
diff++;
1305+
*ioDiff += diff;
1306+
if (diff)
1307+
{
1308+
auto newVal = tryFoldImpl(astBuilder, substType, substBase, nullptr);
1309+
if (newVal)
1310+
return newVal;
1311+
else
1312+
{
1313+
auto result = astBuilder->create<TypeCastIntVal>(substType, substBase);
1314+
return result;
1315+
}
1316+
}
1317+
// Nothing found: don't substitute.
1318+
return this;
1319+
}
1320+
1321+
12201322
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncCallIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
12211323

12221324
bool FuncCallIntVal::_equalsValOverride(Val* val)

source/slang/slang-ast-val.h

+15
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,21 @@ class GenericParamIntVal : public IntVal
6363
{}
6464
};
6565

66+
class TypeCastIntVal : public IntVal
67+
{
68+
SLANG_AST_CLASS(TypeCastIntVal)
69+
70+
bool _equalsValOverride(Val* val);
71+
void _toTextOverride(StringBuilder& out);
72+
HashCode _getHashCodeOverride();
73+
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
74+
75+
Val* base;
76+
TypeCastIntVal(Type* inType, Val* inBase) : IntVal(inType), base(inBase) {}
77+
78+
static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink);
79+
};
80+
6681
// An compile time int val as result of some general computation.
6782
class FuncCallIntVal : public IntVal
6883
{

source/slang/slang-check-expr.cpp

+22-22
Original file line numberDiff line numberDiff line change
@@ -1385,26 +1385,12 @@ namespace Slang
13851385
auto targetBasicType = as<BasicExpressionType>(invokeExpr.getExpr()->type.type);
13861386
if (!targetBasicType)
13871387
return nullptr;
1388-
switch (targetBasicType->baseType)
1389-
{
1390-
case BaseType::Bool:
1391-
resultValue = constArgVals[0] != 0;
1392-
break;
1393-
case BaseType::Int:
1394-
case BaseType::UInt:
1395-
case BaseType::UInt16:
1396-
case BaseType::Int16:
1397-
case BaseType::UInt8:
1398-
case BaseType::Int8:
1399-
case BaseType::UIntPtr:
1400-
case BaseType::IntPtr:
1401-
case BaseType::Int64:
1402-
case BaseType::UInt64:
1403-
resultValue = constArgVals[0];
1404-
break;
1405-
default:
1406-
return nullptr;
1407-
}
1388+
auto foldVal = as<IntVal>(
1389+
TypeCastIntVal::tryFoldImpl(m_astBuilder, targetBasicType, argVals[0], getSink()));
1390+
if (foldVal)
1391+
return foldVal;
1392+
auto result = m_astBuilder->getOrCreate<TypeCastIntVal>(targetBasicType, argVals[0]);
1393+
return result;
14081394
}
14091395
else
14101396
{
@@ -1619,9 +1605,23 @@ namespace Slang
16191605

16201606
if(auto castExpr = expr.as<TypeCastExpr>())
16211607
{
1608+
auto substType = getType(m_astBuilder, expr);
1609+
if (!substType)
1610+
return nullptr;
1611+
if (!isScalarIntegerType(substType))
1612+
return nullptr;
16221613
auto val = tryConstantFoldExpr(getArg(castExpr, 0), circularityInfo);
1623-
if(val)
1624-
return val;
1614+
if (val)
1615+
{
1616+
if (!castExpr.getExpr()->type)
1617+
return nullptr;
1618+
auto foldVal = as<IntVal>(
1619+
TypeCastIntVal::tryFoldImpl(m_astBuilder, substType, val, getSink()));
1620+
if (foldVal)
1621+
return foldVal;
1622+
auto result = m_astBuilder->getOrCreate<TypeCastIntVal>(substType, val);
1623+
return result;
1624+
}
16251625
}
16261626
else if (auto invokeExpr = expr.as<InvokeExpr>())
16271627
{

source/slang/slang-lower-to-ir.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1432,6 +1432,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
14321432
tryEnv);
14331433
}
14341434

1435+
LoweredValInfo visitTypeCastIntVal(TypeCastIntVal* val)
1436+
{
1437+
TryClauseEnvironment tryEnv;
1438+
auto baseVal = lowerVal(context, val->base);
1439+
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
1440+
auto type = lowerType(context, val->type);
1441+
return LoweredValInfo::simple(getBuilder()->emitCast(type, baseVal.val));
1442+
}
1443+
14351444
LoweredValInfo visitWitnessLookupIntVal(WitnessLookupIntVal* val)
14361445
{
14371446
auto witnessVal = lowerVal(context, val->witness);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//TEST(compute):COMPARE_COMPUTE:
2+
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -entry computeMain -profile cs_5_0
3+
4+
int check<let b : bool>(int x)
5+
{
6+
// CHECK: int v{{.*}}[int(2)]
7+
int v[((int)b) + 1];
8+
for (int i = 0; i < ((int)b) + 1; i++)
9+
v[i] = i;
10+
return (int)v[x];
11+
}
12+
13+
//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
14+
RWStructuredBuffer<int> outputBuffer;
15+
16+
[numthreads(1, 1, 1)]
17+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
18+
{
19+
int tid = dispatchThreadID.x;
20+
int inVal = tid;
21+
int outVal = check<true>(inVal + 1);
22+
outputBuffer[tid] = outVal;
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//TEST(compute):COMPARE_COMPUTE: -output-using-type
2+
3+
//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
4+
RWStructuredBuffer<int> outputBuffer;
5+
6+
static const int c = (int8_t)255;
7+
8+
[numthreads(1, 1, 1)]
9+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
10+
{
11+
int tid = dispatchThreadID.x;
12+
int inVal = tid;
13+
int outVal = c;
14+
outputBuffer[tid] = outVal;
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
type: int32_t
2+
-1

tests/pipeline/ray-tracing/ray-query-subroutine.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
RWStructuredBuffer<int> gOutput;
88
RaytracingAccelerationStructure gScene;
99

10-
float3 helper<let N : int>(RayQuery<N> q)
10+
float3 helper<let N : uint>(RayQuery<N> q)
1111
{
1212
RayDesc ray;
1313
ray.Origin = 0;

0 commit comments

Comments
 (0)