Skip to content

Commit f94b2f7

Browse files
authored
Allow loop counters to be used as constexpr arguments. (shader-slang#3139)
* Allow loop counters to be used as constexpr arguments. * Fix. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent 4de3d9b commit f94b2f7

10 files changed

+260
-99
lines changed

source/slang/slang-ir-autodiff-cfg-norm.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,8 @@ void normalizeCFG(
790790
disableIRValidationAtInsert();
791791
constructSSA(module, func);
792792
enableIRValidationAtInsert();
793+
794+
module->invalidateAnalysisForInst(func);
793795
#if _DEBUG
794796
validateIRInst(maybeFindOuterGeneric(func));
795797
#endif

source/slang/slang-ir-constexpr.cpp

+183-95
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "slang-ir.h"
55
#include "slang-ir-insts.h"
6+
#include "slang-ir-dominators.h"
67

78
namespace Slang {
89

@@ -74,6 +75,7 @@ bool opCanBeConstExpr(IROp op)
7475
case kIROp_IntLit:
7576
case kIROp_FloatLit:
7677
case kIROp_BoolLit:
78+
case kIROp_Param:
7779
case kIROp_Add:
7880
case kIROp_Sub:
7981
case kIROp_Mul:
@@ -142,13 +144,35 @@ bool opCanBeConstExpr(IROp op)
142144
}
143145
}
144146

145-
bool opCanBeConstExpr(IRInst* value)
147+
bool opCanBeConstExprByForwardPass(IRInst* value)
146148
{
147149
// TODO: realistically need to special-case `call`
148150
// operations here, so that we check whether the
149151
// callee function is fixed/known, and if it is
150152
// whether it has been declared as constant-foldable
153+
if (value->getOp() == kIROp_Param)
154+
return false;
155+
return opCanBeConstExpr(value->getOp());
156+
}
157+
158+
IRLoop* isLoopPhi(IRParam* param)
159+
{
160+
IRBlock* bb = cast<IRBlock>(param->getParent());
161+
for (auto pred : bb->getPredecessors())
162+
{
163+
auto loop = as<IRLoop>(pred->getTerminator());
164+
if (loop)
165+
{
166+
return loop;
167+
}
168+
}
169+
return nullptr;
170+
}
151171

172+
bool opCanBeConstExprByBackwardPass(IRInst* value)
173+
{
174+
if (value->getOp() == kIROp_Param)
175+
return isLoopPhi(as<IRParam>(value));
152176
return opCanBeConstExpr(value->getOp());
153177
}
154178

@@ -159,6 +183,110 @@ void markConstExpr(
159183
Slang::markConstExpr(context->getBuilder(), value);
160184
}
161185

186+
void maybeAddToWorkList(
187+
PropagateConstExprContext* context,
188+
IRInst* gv)
189+
{
190+
if (!context->onWorkList.contains(gv))
191+
{
192+
context->workList.add(gv);
193+
context->onWorkList.add(gv);
194+
}
195+
}
196+
197+
bool maybeMarkConstExprBackwardPass(
198+
PropagateConstExprContext* context,
199+
IRInst* value)
200+
{
201+
if (isConstExpr(value))
202+
return false;
203+
204+
if (!opCanBeConstExprByBackwardPass(value))
205+
return false;
206+
207+
markConstExpr(context, value);
208+
209+
// TODO: we should only allow function parameters to be
210+
// changed to be `constexpr` when we are compiling "application"
211+
// code, and not library code.
212+
// (Or eventually we'd have a rule that only non-`public` symbols
213+
// can have this kind of propagation applied).
214+
215+
if (value->getOp() == kIROp_Param)
216+
{
217+
auto param = (IRParam*)value;
218+
auto block = (IRBlock*)param->parent;
219+
auto code = block->getParent();
220+
221+
if (block == code->getFirstBlock())
222+
{
223+
// We've just changed a function parameter to
224+
// be `constexpr`. We need to remember that
225+
// fact so taht we can mark callers of this
226+
// function as `constexpr` themselves.
227+
228+
for (auto u = code->firstUse; u; u = u->nextUse)
229+
{
230+
auto user = u->getUser();
231+
232+
switch (user->getOp())
233+
{
234+
case kIROp_Call:
235+
{
236+
auto inst = (IRCall*)user;
237+
auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent());
238+
maybeAddToWorkList(context, caller);
239+
}
240+
break;
241+
242+
default:
243+
break;
244+
}
245+
}
246+
}
247+
}
248+
249+
return true;
250+
}
251+
252+
// Produce an estimate on whether a loop is unrollable, by checking
253+
// if there is at least one exit path where all the conditions along
254+
// the control path has a constexpr condition.
255+
bool isUnrollableLoop(IRLoop* loop)
256+
{
257+
// A loop is unrollable if all exit conditions are constexpr.
258+
auto breakBlock = loop->getBreakBlock();
259+
auto func = getParentFunc(loop);
260+
auto domTree = loop->getModule()->findOrCreateDominatorTree(func);
261+
List<IRBlock*> workList;
262+
bool result = false;
263+
for (auto pred : breakBlock->getPredecessors())
264+
{
265+
workList.clear();
266+
workList.add(pred);
267+
for (Index i = 0; i < workList.getCount(); i++)
268+
{
269+
auto block = workList[i];
270+
if (auto ifElse = as<IRConditionalBranch>(block->getTerminator()))
271+
{
272+
if (!isConstExpr(ifElse->getCondition()))
273+
return false;
274+
}
275+
else if (auto switchInst = as<IRSwitch>(block->getTerminator()))
276+
{
277+
if (!isConstExpr(ifElse->getCondition()))
278+
return false;
279+
}
280+
auto idom = domTree->getImmediateDominator(block);
281+
if (idom && idom != loop->getParent())
282+
workList.add(idom);
283+
}
284+
// We found at least one exit path that is constexpr,
285+
// we will regard this loop as unrollable.
286+
result = true;
287+
}
288+
return result;
289+
}
162290

163291
// Propagate `constexpr`-ness in a forward direction, from the
164292
// operands of an instruction to the instruction itself.
@@ -179,7 +307,7 @@ bool propagateConstExprForward(
179307
continue;
180308

181309
// Is the operation one that we can actually make be constexpr?
182-
if(!opCanBeConstExpr(ii))
310+
if(!opCanBeConstExprByForwardPass(ii))
183311
continue;
184312

185313
// Are all arguments `constexpr`?
@@ -211,71 +339,6 @@ bool propagateConstExprForward(
211339
}
212340
}
213341

214-
void maybeAddToWorkList(
215-
PropagateConstExprContext* context,
216-
IRInst* gv)
217-
{
218-
if( !context->onWorkList.contains(gv) )
219-
{
220-
context->workList.add(gv);
221-
context->onWorkList.add(gv);
222-
}
223-
}
224-
225-
bool maybeMarkConstExpr(
226-
PropagateConstExprContext* context,
227-
IRInst* value)
228-
{
229-
if(isConstExpr(value))
230-
return false;
231-
232-
if(!opCanBeConstExpr(value))
233-
return false;
234-
235-
markConstExpr(context, value);
236-
237-
// TODO: we should only allow function parameters to be
238-
// changed to be `constexpr` when we are compiling "application"
239-
// code, and not library code.
240-
// (Or eventually we'd have a rule that only non-`public` symbols
241-
// can have this kind of propagation applied).
242-
243-
if(value->getOp() == kIROp_Param)
244-
{
245-
auto param = (IRParam*) value;
246-
auto block = (IRBlock*) param->parent;
247-
auto code = block->getParent();
248-
249-
if(block == code->getFirstBlock())
250-
{
251-
// We've just changed a function parameter to
252-
// be `constexpr`. We need to remember that
253-
// fact so taht we can mark callers of this
254-
// function as `constexpr` themselves.
255-
256-
for( auto u = code->firstUse; u; u = u->nextUse )
257-
{
258-
auto user = u->getUser();
259-
260-
switch( user->getOp() )
261-
{
262-
case kIROp_Call:
263-
{
264-
auto inst = (IRCall*) user;
265-
auto caller = as<IRGlobalValueWithCode>(inst->getParent()->getParent());
266-
maybeAddToWorkList(context, caller);
267-
}
268-
break;
269-
270-
default:
271-
break;
272-
}
273-
}
274-
}
275-
}
276-
277-
return true;
278-
}
279342

280343
// Propagate `constexpr`-ness in a backward direction, from an instruction
281344
// to its operands.
@@ -312,10 +375,10 @@ bool propagateConstExprBackward(
312375
if(isConstExpr(arg))
313376
continue;
314377

315-
if(!opCanBeConstExpr(arg))
378+
if(!opCanBeConstExprByBackwardPass(arg))
316379
continue;
317380

318-
if( maybeMarkConstExpr(context, arg) )
381+
if( maybeMarkConstExprBackwardPass(context, arg) )
319382
{
320383
changedThisIteration = true;
321384
}
@@ -383,7 +446,7 @@ bool propagateConstExprBackward(
383446

384447
if(isConstExpr(param))
385448
{
386-
if(maybeMarkConstExpr(context, arg))
449+
if(maybeMarkConstExprBackwardPass(context, arg))
387450
{
388451
changedThisIteration = true;
389452
}
@@ -411,7 +474,7 @@ bool propagateConstExprBackward(
411474
auto arg = callInst->getOperand(firstCallArg + pp);
412475
if( isConstExpr(paramType) )
413476
{
414-
if( maybeMarkConstExpr(context, arg) )
477+
if(maybeMarkConstExprBackwardPass(context, arg) )
415478
{
416479
changedThisIteration = true;
417480
}
@@ -439,15 +502,14 @@ bool propagateConstExprBackward(
439502

440503
for(auto pred : bb->getPredecessors())
441504
{
442-
auto terminator = pred->getLastInst();
443-
if(terminator->getOp() != kIROp_unconditionalBranch)
505+
auto terminator = as<IRUnconditionalBranch>(pred->getLastInst());
506+
if(!terminator)
444507
continue;
445508

446-
UInt operandIndex = paramIndex + 1;
447-
SLANG_RELEASE_ASSERT(operandIndex < terminator->getOperandCount());
509+
SLANG_RELEASE_ASSERT(paramIndex < terminator->getArgCount());
448510

449-
auto operand = terminator->getOperand(operandIndex);
450-
if( maybeMarkConstExpr(context, operand) )
511+
auto operand = terminator->getArg(paramIndex);
512+
if(maybeMarkConstExprBackwardPass(context, operand) )
451513
{
452514
changedThisIteration = true;
453515
}
@@ -463,7 +525,6 @@ bool propagateConstExprBackward(
463525
anyChanges = true;
464526
}
465527
}
466-
467528
// Validate use of `constexpr` within a function (in particular,
468529
// diagnose places where a value that must be contexpr depends
469530
// on a value that cannot be)
@@ -484,9 +545,32 @@ void validateConstExpr(
484545
for( UInt aa = 0; aa < argCount; ++aa )
485546
{
486547
auto arg = ii->getOperand(aa);
487-
488-
if( !isConstExpr(arg) )
548+
bool shouldDiagnose = !isConstExpr(arg);
549+
if (!shouldDiagnose)
550+
{
551+
if (auto param = as<IRParam>(arg))
552+
{
553+
if (IRLoop * loopInst = isLoopPhi(param))
554+
{
555+
// If the param is a phi node in a loop that
556+
// does not depend on non-constexpr values, we
557+
// can make it constexpr by force unrolling the
558+
// loop, if the loop is unrollable.
559+
if (isUnrollableLoop(loopInst))
560+
{
561+
if (!loopInst->findDecoration<IRForceUnrollDecoration>())
562+
{
563+
context->getBuilder()->addLoopForceUnrollDecoration(loopInst, 0);
564+
}
565+
continue;
566+
}
567+
shouldDiagnose = true;
568+
}
569+
}
570+
}
571+
if (shouldDiagnose)
489572
{
573+
490574
// Diagnose the failure.
491575

492576
context->getSink()->diagnose(ii->sourceLoc, Diagnostics::needCompileTimeConstant);
@@ -499,6 +583,24 @@ void validateConstExpr(
499583
}
500584
}
501585

586+
void propagateInFunc(PropagateConstExprContext* context, IRGlobalValueWithCode* code)
587+
{
588+
for (;;)
589+
{
590+
bool anyChange = false;
591+
if (propagateConstExprForward(context, code))
592+
{
593+
anyChange = true;
594+
}
595+
if (propagateConstExprBackward(context, code))
596+
{
597+
anyChange = true;
598+
}
599+
if (!anyChange)
600+
break;
601+
}
602+
}
603+
502604
void propagateConstExpr(
503605
IRModule* module,
504606
DiagnosticSink* sink)
@@ -548,21 +650,7 @@ void propagateConstExpr(
548650
case kIROp_GlobalVar:
549651
{
550652
IRGlobalValueWithCode* code = (IRGlobalValueWithCode*) gv;
551-
552-
for( ;;)
553-
{
554-
bool anyChange = false;
555-
if( propagateConstExprForward(&context, code) )
556-
{
557-
anyChange = true;
558-
}
559-
if( propagateConstExprBackward(&context, code) )
560-
{
561-
anyChange = true;
562-
}
563-
if(!anyChange)
564-
break;
565-
}
653+
propagateInFunc(&context, code);
566654
}
567655
break;
568656
}

0 commit comments

Comments
 (0)