Skip to content

Commit 36a06f1

Browse files
author
Tim Foley
authored
Diagnose circularly-defined constants (shader-slang#1384)
* Diagnose circularly-defined constants Work on shader-slang#1374 This change diagnoses cases like the following: ```hlsl static const int kCircular = kCircular; static const int kInfinite = kInfinite + 1; static const int kHere = kThere; static const int kThere = kHere; ``` By diagnosing these as errors in the front-end we protect against infinite recursion leading to stack overflow crashes. The basic approach is to have front-end constant folding track variables that are in use when folding a sub-expression, and then diagnosing an error if the same variable is encountered again while it is in use. In order to make sure the error occurs whether or not the constant is referenced, we invoke constant folding on all `static const` integer variables. Limitations: * This only works for integers, since that is all front-end constant folding applies to. A future change can/should catch circularity in constants at the IR level (and handle more types). * This only works for constants. Circular references in the definition of a global variable are harder to diagnose, but at least shouldn't result in compiler crashes. * This doesn't work across modules, or through generic specialization: anything that requires global knowledge won't be checked * fixup: missing files * fixup: review feedback
1 parent 2359921 commit 36a06f1

6 files changed

+199
-52
lines changed

source/slang/slang-check-decl.cpp

+43-11
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,24 @@ namespace Slang
775775
return true;
776776
}
777777

778+
void SemanticsVisitor::_validateCircularVarDefinition(VarDeclBase* varDecl)
779+
{
780+
// The easiest way to test if the declaration is circular is to
781+
// validate it as a constant.
782+
//
783+
// TODO: The logic here will only apply for `static const` declarations
784+
// of integer type, given that our constant folding currently only
785+
// applies to such types. A more robust fix would involve a truly
786+
// recursive walk of the AST declarations, and an even *more* robust
787+
// fix would wait until after IR linking to detect and diagnose circularity
788+
// in case it crosses module boundaries.
789+
//
790+
//
791+
if(!isScalarIntegerType(varDecl->type))
792+
return;
793+
tryConstantFoldDeclRef(DeclRef<VarDeclBase>(varDecl, nullptr), nullptr);
794+
}
795+
778796
void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
779797
{
780798
// A variable that didn't have an explicit type written must
@@ -804,6 +822,8 @@ namespace Slang
804822

805823
varDecl->initExpr = initExpr;
806824
varDecl->type.type = initExpr->type;
825+
826+
_validateCircularVarDefinition(varDecl);
807827
}
808828

809829
// If we've gone down this path, then the variable
@@ -857,11 +877,16 @@ namespace Slang
857877
{
858878
// If the variable has an explicit initial-value expression,
859879
// then we simply need to check that expression and coerce
860-
// it to the tyep of the variable.
880+
// it to the type of the variable.
861881
//
862882
initExpr = CheckTerm(initExpr);
863883
initExpr = coerce(varDecl->type.Ptr(), initExpr);
864884
varDecl->initExpr = initExpr;
885+
886+
// We need to ensure that any variable doesn't introduce
887+
// a constant with a circular definition.
888+
//
889+
_validateCircularVarDefinition(varDecl);
865890
}
866891
else
867892
{
@@ -1970,18 +1995,25 @@ namespace Slang
19701995
return (BaseTypeInfo::getInfo(baseType).flags & BaseTypeInfo::Flag::Integer) != 0;
19711996
}
19721997

1973-
void SemanticsVisitor::validateEnumTagType(Type* type, SourceLoc const& loc)
1998+
bool SemanticsVisitor::isScalarIntegerType(Type* type)
19741999
{
1975-
if(auto basicType = as<BasicExpressionType>(type))
1976-
{
1977-
// Allow the built-in integer types.
1978-
if(isIntegerBaseType(basicType->baseType))
1979-
return;
2000+
auto basicType = as<BasicExpressionType>(type);
2001+
if(!basicType)
2002+
return false;
19802003

1981-
// By default, don't allow other types to be used
1982-
// as an `enum` tag type.
1983-
}
2004+
return isIntegerBaseType(basicType->baseType);
2005+
}
19842006

2007+
void SemanticsVisitor::validateEnumTagType(Type* type, SourceLoc const& loc)
2008+
{
2009+
// Allow the built-in integer types.
2010+
//
2011+
if(isScalarIntegerType(type))
2012+
return;
2013+
2014+
// By default, don't allow other types to be used
2015+
// as an `enum` tag type.
2016+
//
19852017
getSink()->diagnose(loc, Diagnostics::invalidEnumTagType, type);
19862018
}
19872019

@@ -2177,7 +2209,7 @@ namespace Slang
21772209
// the tag value for a successor case that doesn't
21782210
// provide an explicit tag.
21792211

2180-
IntVal* explicitTagVal = TryConstantFoldExpr(explicitTagValExpr);
2212+
IntVal* explicitTagVal = tryConstantFoldExpr(explicitTagValExpr, nullptr);
21812213
if(explicitTagVal)
21822214
{
21832215
if(auto constIntVal = as<ConstantIntVal>(explicitTagVal))

source/slang/slang-check-expr.cpp

+69-34
Original file line numberDiff line numberDiff line change
@@ -669,8 +669,9 @@ namespace Slang
669669
return m_astBuilder->create<ConstantIntVal>(expr->value);
670670
}
671671

672-
IntVal* SemanticsVisitor::TryConstantFoldExpr(
673-
InvokeExpr* invokeExpr)
672+
IntVal* SemanticsVisitor::tryConstantFoldExpr(
673+
InvokeExpr* invokeExpr,
674+
ConstantFoldingCircularityInfo* circularityInfo)
674675
{
675676
// We need all the operands to the expression
676677

@@ -707,7 +708,7 @@ namespace Slang
707708
bool allConst = true;
708709
for (auto argExpr : invokeExpr->arguments)
709710
{
710-
auto argVal = TryCheckIntegerConstantExpression(argExpr);
711+
auto argVal = tryFoldIntegerConstantExpression(argExpr, circularityInfo);
711712
if (!argVal)
712713
return nullptr;
713714

@@ -795,8 +796,53 @@ namespace Slang
795796
return result;
796797
}
797798

798-
IntVal* SemanticsVisitor::TryConstantFoldExpr(
799-
Expr* expr)
799+
bool SemanticsVisitor::_checkForCircularityInConstantFolding(
800+
Decl* decl,
801+
ConstantFoldingCircularityInfo* circularityInfo)
802+
{
803+
// TODO: If the `decl` is already on the chain of `circularityInfo`,
804+
// then we know that we are trying to recursively fold the
805+
// same declaration as part of its own definition, and we need
806+
// to diagnose that as an error.
807+
//
808+
for( auto info = circularityInfo; info; info = info->next )
809+
{
810+
if(decl == info->decl)
811+
{
812+
getSink()->diagnose(decl, Diagnostics::variableUsedInItsOwnDefinition, decl);
813+
return true;
814+
}
815+
}
816+
817+
return false;
818+
}
819+
820+
IntVal* SemanticsVisitor::tryConstantFoldDeclRef(
821+
DeclRef<VarDeclBase> const& declRef,
822+
ConstantFoldingCircularityInfo* circularityInfo)
823+
{
824+
auto decl = declRef.getDecl();
825+
826+
if(_checkForCircularityInConstantFolding(decl, circularityInfo))
827+
return nullptr;
828+
829+
// In HLSL, `static const` is used to mark compile-time constant expressions
830+
if(!decl->hasModifier<HLSLStaticModifier>())
831+
return nullptr;
832+
if(!decl->hasModifier<ConstModifier>())
833+
return nullptr;
834+
835+
auto initExpr = getInitExpr(m_astBuilder, declRef);
836+
if(!initExpr)
837+
return nullptr;
838+
839+
ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo);
840+
return tryConstantFoldExpr(initExpr, &newCircularityInfo);
841+
}
842+
843+
IntVal* SemanticsVisitor::tryConstantFoldExpr(
844+
Expr* expr,
845+
ConstantFoldingCircularityInfo* circularityInfo)
800846
{
801847
// Unwrap any "identity" expressions
802848
while (auto parenExpr = as<ParenExpr>(expr))
@@ -825,62 +871,51 @@ namespace Slang
825871
// are defined in a way that can be used as a constant expression:
826872
if(auto varRef = declRef.as<VarDeclBase>())
827873
{
828-
auto varDecl = varRef.getDecl();
829-
830-
// In HLSL, `static const` is used to mark compile-time constant expressions
831-
if(auto staticAttr = varDecl->findModifier<HLSLStaticModifier>())
832-
{
833-
if(auto constAttr = varDecl->findModifier<ConstModifier>())
834-
{
835-
// HLSL `static const` can be used as a constant expression
836-
if(auto initExpr = getInitExpr(m_astBuilder, varRef))
837-
{
838-
return TryConstantFoldExpr(initExpr);
839-
}
840-
}
841-
}
874+
return tryConstantFoldDeclRef(varRef, circularityInfo);
842875
}
843876
else if(auto enumRef = declRef.as<EnumCaseDecl>())
844877
{
845878
// The cases in an `enum` declaration can also be used as constant expressions,
846879
if(auto tagExpr = getTagExpr(m_astBuilder, enumRef))
847880
{
848-
return TryConstantFoldExpr(tagExpr);
881+
auto enumCaseDecl = enumRef.getDecl();
882+
if(_checkForCircularityInConstantFolding(enumCaseDecl, circularityInfo))
883+
return nullptr;
884+
885+
ConstantFoldingCircularityInfo newCircularityInfo(enumCaseDecl, circularityInfo);
886+
return tryConstantFoldExpr(tagExpr, &newCircularityInfo);
849887
}
850888
}
851889
}
852890

853891
if(auto castExpr = as<TypeCastExpr>(expr))
854892
{
855-
auto val = TryConstantFoldExpr(castExpr->arguments[0]);
893+
auto val = tryConstantFoldExpr(castExpr->arguments[0], circularityInfo);
856894
if(val)
857895
return val;
858896
}
859897
else if (auto invokeExpr = as<InvokeExpr>(expr))
860898
{
861-
auto val = TryConstantFoldExpr(invokeExpr);
899+
auto val = tryConstantFoldExpr(invokeExpr, circularityInfo);
862900
if (val)
863901
return val;
864902
}
865903

866904
return nullptr;
867905
}
868906

869-
IntVal* SemanticsVisitor::TryCheckIntegerConstantExpression(Expr* exp)
907+
IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression(
908+
Expr* expr,
909+
ConstantFoldingCircularityInfo* circularityInfo)
870910
{
871911
// Check if type is acceptable for an integer constant expression
872-
if(auto basicType = as<BasicExpressionType>(exp->type.type))
873-
{
874-
if(!isIntegerBaseType(basicType->baseType))
875-
return nullptr;
876-
}
877-
else
878-
{
912+
//
913+
if(!isScalarIntegerType(expr->type))
879914
return nullptr;
880-
}
881915

882916
// Consider operations that we might be able to constant-fold...
883-
return TryConstantFoldExpr(exp);
917+
//
918+
return tryConstantFoldExpr(expr, circularityInfo);
884919
}
885920

886921
IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, DiagnosticSink* sink)
@@ -894,7 +929,7 @@ namespace Slang
894929
// No need to issue further errors if the type coercion failed.
895930
if(IsErrorExpr(expr)) return nullptr;
896931

897-
auto result = TryCheckIntegerConstantExpression(expr);
932+
auto result = tryFoldIntegerConstantExpression(expr, nullptr);
898933
if (!result && sink)
899934
{
900935
sink->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant);
@@ -915,7 +950,7 @@ namespace Slang
915950
// No need to issue further errors if the type coercion failed.
916951
if(IsErrorExpr(expr)) return nullptr;
917952

918-
auto result = TryConstantFoldExpr(expr);
953+
auto result = tryConstantFoldExpr(expr, nullptr);
919954
if (!result)
920955
{
921956
getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant);

source/slang/slang-check-impl.h

+57-7
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,18 @@ namespace Slang
491491
// Capture the "base" expression in case this is a member reference
492492
Expr* GetBaseExpr(Expr* expr);
493493

494+
/// Validate a declaration to ensure that it doesn't introduce a circularly-defined constant
495+
///
496+
/// Circular definition in a constant may lead to infinite looping or stack overflow in
497+
/// the compiler, so it needs to be protected against.
498+
///
499+
/// Note that this function does *not* protect against circular definitions in general,
500+
/// and a program that indirectly initializes a global variable using its own value (e.g.,
501+
/// by calling a function that indirectly reads the variable) will be allowed and then
502+
/// exhibit undefined behavior at runtime.
503+
///
504+
void _validateCircularVarDefinition(VarDeclBase* varDecl);
505+
494506
public:
495507

496508
bool ValuesAreEqual(
@@ -778,6 +790,9 @@ namespace Slang
778790

779791
bool isIntegerBaseType(BaseType baseType);
780792

793+
/// Is `type` a scalar integer type.
794+
bool isScalarIntegerType(Type* type);
795+
781796
// Validate that `type` is a suitable type to use
782797
// as the tag type for an `enum`
783798
void validateEnumTagType(Type* type, SourceLoc const& loc);
@@ -827,15 +842,50 @@ namespace Slang
827842
return getNamePool()->getName(text);
828843
}
829844

830-
IntVal* TryConstantFoldExpr(
831-
InvokeExpr* invokeExpr);
845+
/// Helper type to detect and catch circular definitions when folding constants,
846+
/// to prevent the compiler from going into infinite loops or overflowing the stack.
847+
struct ConstantFoldingCircularityInfo
848+
{
849+
ConstantFoldingCircularityInfo(
850+
Decl* decl,
851+
ConstantFoldingCircularityInfo* next)
852+
: decl(decl)
853+
, next(next)
854+
{}
855+
856+
/// A declaration whose value is contributing to the constant being folded
857+
Decl* decl = nullptr;
858+
859+
/// The rest of the links in the chain of declarations being folded
860+
ConstantFoldingCircularityInfo* next = nullptr;
861+
};
862+
863+
/// Try to apply front-end constant folding to determine the value of `invokeExpr`.
864+
IntVal* tryConstantFoldExpr(
865+
InvokeExpr* invokeExpr,
866+
ConstantFoldingCircularityInfo* circularityInfo);
867+
868+
/// Try to apply front-end constant folding to determine the value of `expr`.
869+
IntVal* tryConstantFoldExpr(
870+
Expr* expr,
871+
ConstantFoldingCircularityInfo* circularityInfo);
832872

833-
IntVal* TryConstantFoldExpr(
834-
Expr* expr);
873+
bool _checkForCircularityInConstantFolding(
874+
Decl* decl,
875+
ConstantFoldingCircularityInfo* circularityInfo);
835876

836-
// Try to check an integer constant expression, either returning the value,
837-
// or NULL if the expression isn't recognized as a constant.
838-
IntVal* TryCheckIntegerConstantExpression(Expr* exp);
877+
/// Try to resolve a compile-time constant `IntVal` from the given `declRef`.
878+
IntVal* tryConstantFoldDeclRef(
879+
DeclRef<VarDeclBase> const& declRef,
880+
ConstantFoldingCircularityInfo* circularityInfo);
881+
882+
/// Try to extract the value of an integer constant expression, either
883+
/// returning the `IntVal` value, or null if the expression isn't recognized
884+
/// as an integer constant.
885+
///
886+
IntVal* tryFoldIntegerConstantExpression(
887+
Expr* expr,
888+
ConstantFoldingCircularityInfo* circularityInfo);
839889

840890
// Enforce that an expression resolves to an integer constant, and get its value
841891
IntVal* CheckIntegerConstantExpression(Expr* inExpr);

source/slang/slang-diagnostic-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only
299299
// TODO: need to assign numbers to all these extra diagnostics...
300300
DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.")
301301
DIAGNOSTIC(39999, Fatal, localVariableUsedBeforeDeclared, "local variable '$0' is being used before its declaration.")
302+
DIAGNOSTIC(39999, Error, variableUsedInItsOwnDefinition, "the initial-value expression for variable '$0' depends on the value of the variable itself")
302303

303304
// 304xx: generics
304305
DIAGNOSTIC(30400, Error, genericTypeNeedsArgs, "generic type '$0' used without argument")

tests/diagnostics/gh-1374.slang

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// gh-1374.slang
2+
3+
// Test a `static` variable wwith a definition that refers to itself
4+
5+
//DIAGNOSTIC_TEST:REFLECTION:-stage compute -entry main -target hlsl
6+
7+
struct S
8+
{
9+
static const int kVal = kVal;
10+
11+
static const int kInf = kInf + 1;
12+
13+
static const int kA = kB;
14+
static const int kB = kA;
15+
}
16+
17+
[numthreads(1, 1, 1)]
18+
void main(
19+
uint3 dispatchThreadID : SV_DispatchThreadID)
20+
{
21+
}
+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
result code = 1
2+
standard error = {
3+
tests/diagnostics/gh-1374.slang(9): error 39999: the initial-value expression for variable 'kVal' depends on the value of the variable itself
4+
tests/diagnostics/gh-1374.slang(11): error 39999: the initial-value expression for variable 'kInf' depends on the value of the variable itself
5+
tests/diagnostics/gh-1374.slang(14): error 39999: the initial-value expression for variable 'kB' depends on the value of the variable itself
6+
}
7+
standard output = {
8+
}

0 commit comments

Comments
 (0)