Skip to content

Commit b4c4dc9

Browse files
author
Tim Foley
authored
Add support for default parameter values in IR codegen (shader-slang#459)
Fixes shader-slang#61 When lowering from AST to IR, if a call site doesn't supply an argument expression for each of the parameters to the callee, then use the default value expressions (stored as the "initializer" of the parameter decl) for each omitted parameter. This relies on the front-end to have already checked the call site for validity. Along the way I also cleaned up some of the checking of parameter declarations so that it is more like the checking of ordinary variable declarations (although the code is not yet shared). I also cleaned out some dead cases in the lowering logic for when we don't actually have a declaration available for a callee (these would only matter if we supported functions as first-class values). I added a simple test case to confirm that call sites both with and without the optional parameter work as expected. The strategy in this change is extremely simplistic, and might only be appropriate for default parameter value expressions that are compile-time constants (which should be the 99% case). This may require a major overhaul if we decide to handle default parameter values differently (e.g., by generating extra functions to ensure that the separate compilation story is what we want). Another issue that could change a lot of this logic would be if we start to support by-name parameters at call sites, since we could no longer assume that the argument and parameter lists align one-to-one (with the argument list possibly being shorter). Any work to add more flexible argument passing conventions would need to build a suitable structure to map from arguments to parameters, or vice-versa.
1 parent 184dc5c commit b4c4dc9

6 files changed

+110
-52
lines changed

source/slang/check.cpp

+29-9
Original file line numberDiff line numberDiff line change
@@ -2820,18 +2820,38 @@ namespace Slang
28202820
// Nothing to do
28212821
}
28222822

2823-
void visitParamDecl(ParamDecl* para)
2823+
void visitParamDecl(ParamDecl* paramDecl)
28242824
{
2825-
// TODO: This needs to bottleneck through the common variable checks
2825+
// TODO: This logic should be shared with the other cases of
2826+
// variable declarations. The main reason I am not doing it
2827+
// yet is that we use a `ParamDecl` with a null type as a
2828+
// special case in attribute declarations, and that could
2829+
// trip up the ordinary variable checks.
28262830

2827-
if(para->type.exp)
2831+
auto typeExpr = paramDecl->type;
2832+
if(typeExpr.exp)
28282833
{
2829-
para->type = CheckUsableType(para->type);
2830-
2831-
if (para->type.Equals(getSession()->getVoidType()))
2832-
{
2833-
getSink()->diagnose(para, Diagnostics::parameterCannotBeVoid);
2834-
}
2834+
typeExpr = CheckUsableType(typeExpr);
2835+
paramDecl->type = typeExpr;
2836+
}
2837+
2838+
// The "initializer" expression for a parameter represents
2839+
// a default argument value to use if an explicit one is
2840+
// not supplied.
2841+
if(auto initExpr = paramDecl->initExpr)
2842+
{
2843+
// We must check the expression and coerce it to the
2844+
// actual type of the parameter.
2845+
//
2846+
initExpr = CheckExpr(initExpr);
2847+
initExpr = Coerce(typeExpr.type, initExpr);
2848+
paramDecl->initExpr = initExpr;
2849+
2850+
// TODO: a default argument expression needs to
2851+
// conform to other constraints to be valid.
2852+
// For example, it should not be allowed to refer
2853+
// to other parameters of the same function (or maybe
2854+
// only the parameters to its left...).
28352855
}
28362856
}
28372857

source/slang/diagnostic-defs.h

-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for t
176176
DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.")
177177
DIAGNOSTIC(30015, Error, undefinedIdentifier, "'$0': undefined identifier.")
178178
DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.")
179-
DIAGNOSTIC(30016, Error, parameterCannotBeVoid, "'void' can not be parameter type.")
180179
DIAGNOSTIC(30017, Error, componentNotAccessibleFromShader, "component '$0' is not accessible from shader '$1'.")
181180
DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'")
182181
DIAGNOSTIC(30020, Error, importOperatorReturnTypeMismatch, "import operator should return '$1', but the expression has type '$0''. do you forget 'project'?")

source/slang/lower-to-ir.cpp

+49-42
Original file line numberDiff line numberDiff line change
@@ -1417,21 +1417,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
14171417
SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression");
14181418
}
14191419

1420-
// Add arguments that appeared directly in an argument list
1421-
// to the list of argument values for a call.
1422-
void addDirectCallArgs(
1423-
InvokeExpr* expr,
1424-
List<IRInst*>* ioArgs)
1425-
{
1426-
for( auto arg : expr->Arguments )
1427-
{
1428-
// TODO: Need to handle case of l-value arguments,
1429-
// when they are matched to `out` or `in out` parameters.
1430-
auto loweredArg = lowerRValueExpr(context, arg);
1431-
addArgs(context, ioArgs, loweredArg);
1432-
}
1433-
}
1434-
14351420
// After a call to a function with `out` or `in out`
14361421
// parameters, we may need to copy data back into
14371422
// the l-value locations used for output arguments.
@@ -1452,18 +1437,44 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
14521437
List<OutArgumentFixup>* ioFixups)
14531438
{
14541439
UInt argCount = expr->Arguments.Count();
1455-
UInt argIndex = 0;
1440+
UInt argCounter = 0;
14561441
for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
14571442
{
1458-
if (argIndex >= argCount)
1459-
{
1460-
// The remaining parameters must be defaulted...
1461-
break;
1462-
}
1463-
14641443
auto paramDecl = paramDeclRef.getDecl();
14651444
RefPtr<Type> paramType = lowerSimpleType(context, GetType(paramDeclRef));
1466-
auto argExpr = expr->Arguments[argIndex++];
1445+
1446+
UInt argIndex = argCounter++;
1447+
RefPtr<Expr> argExpr;
1448+
if(argIndex < argCount)
1449+
{
1450+
argExpr = expr->Arguments[argIndex];
1451+
}
1452+
else
1453+
{
1454+
// We have run out of arguments supplied at the call site,
1455+
// but there are still parameters remaining. This must mean
1456+
// that these parameters have default argument expressions
1457+
// associated with them.
1458+
argExpr = getInitExpr(paramDeclRef);
1459+
1460+
// Assert that such an expression must have been present.
1461+
SLANG_ASSERT(argExpr);
1462+
1463+
// TODO: The approach we are taking here to default arguments
1464+
// is simplistic, and has consequences for the front-end as
1465+
// well as binary serializatiojn of modules.
1466+
//
1467+
// We could consider some more refined approaches where, e.g.,
1468+
// functions with default arguments generate multiple IR-level
1469+
// functions, that compute and provide the default values.
1470+
//
1471+
// Alternatively, each parameter with defaults could be generated
1472+
// into its own callable function that provides the default value,
1473+
// so that calling modules can call into a pre-generated function.
1474+
//
1475+
// Each of these options involves trade-offs, and we need to
1476+
// make a conscious decision at some point.
1477+
}
14671478

14681479
if (paramDecl->HasModifier<OutModifier>()
14691480
|| paramDecl->HasModifier<InOutModifier>())
@@ -1543,8 +1554,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
15431554
}
15441555
else
15451556
{
1546-
SLANG_UNEXPECTED("shouldn't relaly happen");
1547-
UNREACHABLE(addDirectCallArgs(expr, ioArgs));
1557+
SLANG_UNEXPECTED("callee was not a callable decl");
15481558
}
15491559
}
15501560

@@ -1696,24 +1706,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
16961706
return result;
16971707
}
16981708

1699-
// The default case is to assume that we just have
1700-
// an ordinary expression, and can lower it as such.
1701-
LoweredValInfo funcVal = lowerRValueExpr(context, expr->FunctionExpr);
1702-
1703-
// Now we add any direct arguments from the call expression itself.
1704-
addDirectCallArgs(expr, &irArgs);
1705-
1706-
// Delegate to the logic for invoking a value.
1707-
auto result = emitCallToVal(context, type, funcVal, irArgs.Count(), irArgs.Buffer());
1708-
1709-
// TODO: because of the nature of how the `emitCallToVal` case works
1710-
// right now, we don't have information on in/out parameters, and
1711-
// so we can't collect info to apply fixups.
1709+
// TODO: In this case we should be emitting code for the callee as
1710+
// an ordinary expression, then emitting the arguments according
1711+
// to the type information on the callee (e.g., which paameters
1712+
// are `out` or `inout`, and then finally emitting the `call`
1713+
// instruciton.
17121714
//
1713-
// Once we have a better representation for function types, though,
1714-
// this should be fixable.
1715-
1716-
return result;
1715+
// We don't currently have the case of emitting arguments according
1716+
// to function type info (instead of declaration info), and really
1717+
// this case can't occur unless we start adding first-class functions
1718+
// to the source language.
1719+
//
1720+
// For now we just bail out with an error.
1721+
//
1722+
SLANG_UNEXPECTED("could not resolve target declaration for call");
1723+
UNREACHABLE_RETURN(LoweredValInfo());
17171724
}
17181725

17191726
LoweredValInfo subscriptValue(

tests/compute/default-parameter.slang

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
5+
6+
RWStructuredBuffer<int> outputBuffer;
7+
8+
int helper(int val, int a = 16)
9+
{
10+
return val + a;
11+
}
12+
13+
int test(int val)
14+
{
15+
return helper(val) + helper(val, 256);
16+
}
17+
18+
[numthreads(4, 1, 1)]
19+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
20+
{
21+
int inVal = (int) dispatchThreadID.x;
22+
int outVal = test(inVal);
23+
outputBuffer[dispatchThreadID.x] = outVal;
24+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
110
2+
112
3+
114
4+
116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
110
2+
112
3+
114
4+
116

0 commit comments

Comments
 (0)