Skip to content

Commit ecc5a39

Browse files
authored
Do recursive function checks early during IR linking (#5777)
1 parent d4136c9 commit ecc5a39

6 files changed

+133
-111
lines changed

source/slang/slang-emit.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include "slang-ir-autodiff.h"
2323
#include "slang-ir-bind-existentials.h"
2424
#include "slang-ir-byte-address-legalize.h"
25-
#include "slang-ir-check-recursive-type.h"
25+
#include "slang-ir-check-recursion.h"
2626
#include "slang-ir-check-shader-parameter-type.h"
2727
#include "slang-ir-check-unsupported-inst.h"
2828
#include "slang-ir-cleanup-void.h"
@@ -884,6 +884,7 @@ Result linkAndOptimizeIR(
884884
if (targetProgram->getOptionSet().shouldRunNonEssentialValidation())
885885
{
886886
checkForRecursiveTypes(irModule, sink);
887+
checkForRecursiveFunctions(codeGenContext->getTargetReq(), irModule, sink);
887888

888889
// For some targets, we are more restrictive about what types are allowed
889890
// to be used as shader parameters in ConstantBuffer/ParameterBlock.
+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#include "slang-ir-check-recursion.h"
2+
3+
#include "slang-ir-util.h"
4+
5+
namespace Slang
6+
{
7+
bool checkTypeRecursionImpl(
8+
HashSet<IRInst*>& checkedTypes,
9+
HashSet<IRInst*>& stack,
10+
IRInst* type,
11+
IRInst* field,
12+
DiagnosticSink* sink)
13+
{
14+
auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool
15+
{
16+
if (!stack.add(elementType))
17+
{
18+
sink->diagnose(field ? field : type, Diagnostics::recursiveType, type);
19+
return false;
20+
}
21+
if (checkedTypes.add(elementType))
22+
checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink);
23+
stack.remove(elementType);
24+
return true;
25+
};
26+
if (auto arrayType = as<IRArrayTypeBase>(type))
27+
{
28+
return visitElementType(arrayType->getElementType(), field);
29+
}
30+
else if (auto structType = as<IRStructType>(type))
31+
{
32+
for (auto sfield : structType->getFields())
33+
if (!visitElementType(sfield->getFieldType(), sfield))
34+
return false;
35+
}
36+
return true;
37+
}
38+
39+
void checkTypeRecursion(HashSet<IRInst*>& checkedTypes, IRInst* type, DiagnosticSink* sink)
40+
{
41+
HashSet<IRInst*> stack;
42+
if (checkedTypes.add(type))
43+
{
44+
stack.add(type);
45+
checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink);
46+
}
47+
}
48+
49+
void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink)
50+
{
51+
HashSet<IRInst*> checkedTypes;
52+
for (auto globalInst : module->getGlobalInsts())
53+
{
54+
switch (globalInst->getOp())
55+
{
56+
case kIROp_StructType:
57+
{
58+
checkTypeRecursion(checkedTypes, globalInst, sink);
59+
}
60+
break;
61+
default:
62+
break;
63+
}
64+
}
65+
}
66+
67+
bool checkFunctionRecursionImpl(
68+
HashSet<IRFunc*>& checkedFuncs,
69+
HashSet<IRFunc*>& callStack,
70+
IRFunc* func,
71+
DiagnosticSink* sink)
72+
{
73+
for (auto block : func->getBlocks())
74+
{
75+
for (auto inst : block->getChildren())
76+
{
77+
auto callInst = as<IRCall>(inst);
78+
if (!callInst)
79+
continue;
80+
auto callee = as<IRFunc>(callInst->getCallee());
81+
if (!callee)
82+
continue;
83+
if (!callStack.add(callee))
84+
{
85+
sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
86+
return false;
87+
}
88+
if (checkedFuncs.add(callee))
89+
checkFunctionRecursionImpl(checkedFuncs, callStack, callee, sink);
90+
callStack.remove(callee);
91+
}
92+
}
93+
return true;
94+
}
95+
96+
void checkFunctionRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
97+
{
98+
HashSet<IRFunc*> callStack;
99+
if (checkedFuncs.add(func))
100+
{
101+
callStack.add(func);
102+
checkFunctionRecursionImpl(checkedFuncs, callStack, func, sink);
103+
}
104+
}
105+
106+
void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
107+
{
108+
HashSet<IRFunc*> checkedFuncsForRecursionDetection;
109+
for (auto globalInst : module->getGlobalInsts())
110+
{
111+
switch (globalInst->getOp())
112+
{
113+
case kIROp_Func:
114+
if (!isCPUTarget(target))
115+
checkFunctionRecursion(
116+
checkedFuncsForRecursionDetection,
117+
as<IRFunc>(globalInst),
118+
sink);
119+
break;
120+
default:
121+
break;
122+
}
123+
}
124+
}
125+
126+
} // namespace Slang

source/slang/slang-ir-check-recursive-type.h source/slang/slang-ir-check-recursion.h

+4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ namespace Slang
44
{
55
struct IRModule;
66
class DiagnosticSink;
7+
class TargetRequest;
78

89
void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink);
10+
11+
void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink);
12+
913
} // namespace Slang

source/slang/slang-ir-check-recursive-type.cpp

-65
This file was deleted.

source/slang/slang-ir-check-unsupported-inst.cpp

-44
Original file line numberDiff line numberDiff line change
@@ -5,46 +5,6 @@
55

66
namespace Slang
77
{
8-
bool isCPUTarget(TargetRequest* targetReq);
9-
10-
bool checkRecursionImpl(
11-
HashSet<IRFunc*>& checkedFuncs,
12-
HashSet<IRFunc*>& callStack,
13-
IRFunc* func,
14-
DiagnosticSink* sink)
15-
{
16-
for (auto block : func->getBlocks())
17-
{
18-
for (auto inst : block->getChildren())
19-
{
20-
auto callInst = as<IRCall>(inst);
21-
if (!callInst)
22-
continue;
23-
auto callee = as<IRFunc>(callInst->getCallee());
24-
if (!callee)
25-
continue;
26-
if (!callStack.add(callee))
27-
{
28-
sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
29-
return false;
30-
}
31-
if (checkedFuncs.add(callee))
32-
checkRecursionImpl(checkedFuncs, callStack, callee, sink);
33-
callStack.remove(callee);
34-
}
35-
}
36-
return true;
37-
}
38-
39-
void checkRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
40-
{
41-
HashSet<IRFunc*> callStack;
42-
if (checkedFuncs.add(func))
43-
{
44-
callStack.add(func);
45-
checkRecursionImpl(checkedFuncs, callStack, func, sink);
46-
}
47-
}
488

499
void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* sink)
5010
{
@@ -65,8 +25,6 @@ void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* s
6525

6626
void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
6727
{
68-
HashSet<IRFunc*> checkedFuncsForRecursionDetection;
69-
7028
for (auto globalInst : module->getGlobalInsts())
7129
{
7230
switch (globalInst->getOp())
@@ -84,8 +42,6 @@ void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSin
8442
break;
8543
}
8644
case kIROp_Func:
87-
if (!isCPUTarget(target))
88-
checkRecursion(checkedFuncsForRecursionDetection, as<IRFunc>(globalInst), sink);
8945
checkUnsupportedInst(target, as<IRFunc>(globalInst), sink);
9046
break;
9147
case kIROp_Generic:

source/slang/slang-lower-to-ir.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "slang-ir-autodiff.h"
1010
#include "slang-ir-bit-field-accessors.h"
1111
#include "slang-ir-check-differentiability.h"
12-
#include "slang-ir-check-recursive-type.h"
12+
#include "slang-ir-check-recursion.h"
1313
#include "slang-ir-clone.h"
1414
#include "slang-ir-constexpr.h"
1515
#include "slang-ir-dce.h"

0 commit comments

Comments
 (0)