Skip to content

Commit f9d99fd

Browse files
author
Tim Foley
authored
Initial support for user-defined initializer/constructor declarations (shader-slang#1233)
The basic idea is that the user can write: ```hlsl struct MyThing { int a; float b; __init(int x, float y) { a = x; b = y; } } ``` and after that point, they can create an intstance of their `MyThing` type as simply as `MyThing(123, 4.56f)`. There was already a large amount of infrastructure laying around that is shared between ininitializers and ordinary functions, so enabling this feature mostly amounted to tying up some loose ends: * In the parser, make sure to properly push/pop the scope for an `__init` (or `__subscript`) declaration, so parameters would be visible to the body * In semantic checking, make sure that declaration "header" checking properly bottlenecks all the function-like cases into a base routine * In semantic checking, make sure that the logic for checking function bodies applies to every `FunctionDeclBase` with a body, and not just `FuncDecl`s * Update semeantic checking for statements to allow for any `FunctionDeclBase` as the parent declaration, not just a `FuncDecl` * In lookup, treat the `this` parameter of an `__init` (well, not actually a *parameter* in this case) as being mutable, just like for a `[mutating]` method * In IR codegen, don't just assume that all `__init`s are intrinsics, and narrow the scope of that hack to just `__init`s without bodies * In IR codegen, detect when we are emitting an IR function for an `__init`, and in that case create a local variable to represent the `this` value, and implicitly return that value at the end of the body. From that point on the rest of the compiler Just Works and IR codegen doesn't have to think of an `__init` as being any different than if the user had declared a `static MyThing make(...)` function. Caveats: * C++ users might like to use that naming convention (so `MyThing` as the name instead of `__init`). We can consider that later. * Everybody else might prefer a keyword other than `__init` (e.g., just `init` as in Swift), but I'm keeping this as a "preview" feature for now, rather than something officially supported * Early `return`s from the body of an `__init` aren't going to work right now. * There is currently no provision for automatically synthesizing initializers for `struct` types based on their fields. This seems like a reasonable direction to take in the future. * There is no provision for routing `{}`-based initializer lists over to initializer calls. The two syntaxes probably need to be unified at some point so that doing `MyType x = { a, b, c }` and `let x = MyType(a, b, c)` are semantically equivalent. It is possible that as a byproduct of this change user-defined `__subscript`s might Just Work, but I am guessing there will still be loose ends on that front as well, so I will refrain from looking into that feature until we have a use case that calls for it.
1 parent dcc3af7 commit f9d99fd

9 files changed

+141
-38
lines changed

source/slang/slang-check-decl.cpp

+22-18
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ namespace Slang
7070

7171
void visitAssocTypeDecl(AssocTypeDecl* decl);
7272

73+
void checkCallableDeclCommon(CallableDecl* decl);
74+
7375
void visitFuncDecl(FuncDecl* funcDecl);
7476

7577
void visitParamDecl(ParamDecl* paramDecl);
@@ -149,7 +151,7 @@ namespace Slang
149151

150152
void visitEnumDecl(EnumDecl* decl);
151153

152-
void visitFuncDecl(FuncDecl* funcDecl);
154+
void visitFunctionDeclBase(FunctionDeclBase* funcDecl);
153155

154156
void visitParamDecl(ParamDecl* paramDecl);
155157
};
@@ -1990,11 +1992,11 @@ namespace Slang
19901992
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
19911993
}
19921994

1993-
void SemanticsDeclBodyVisitor::visitFuncDecl(FuncDecl* funcDecl)
1995+
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
19941996
{
1995-
if (auto body = funcDecl->Body)
1997+
if (auto body = decl->Body)
19961998
{
1997-
checkBodyStmt(body, funcDecl);
1999+
checkBodyStmt(body, decl);
19982000
}
19992001
}
20002002

@@ -2546,6 +2548,14 @@ namespace Slang
25462548
}
25472549
}
25482550

2551+
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
2552+
{
2553+
for(auto& paramDecl : decl->GetParameters())
2554+
{
2555+
ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
2556+
}
2557+
}
2558+
25492559
void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl)
25502560
{
25512561
auto resultType = funcDecl->ReturnType;
@@ -2559,10 +2569,7 @@ namespace Slang
25592569
}
25602570
funcDecl->ReturnType = resultType;
25612571

2562-
for (auto& para : funcDecl->GetParameters())
2563-
{
2564-
ensureDecl(para, DeclCheckState::ReadyForReference);
2565-
}
2572+
checkCallableDeclCommon(funcDecl);
25662573
}
25672574

25682575
IntegerLiteralValue SemanticsVisitor::GetMinBound(RefPtr<IntVal> val)
@@ -2700,6 +2707,7 @@ namespace Slang
27002707
// significant, and we need to make a choice
27012708
// sooner or later.
27022709
//
2710+
ensureDecl(extDeclRef, DeclCheckState::CanUseExtensionTargetType);
27032711
auto targetType = GetTargetType(extDeclRef);
27042712
return calcThisType(targetType);
27052713
}
@@ -2749,23 +2757,15 @@ namespace Slang
27492757

27502758
void SemanticsDeclHeaderVisitor::visitConstructorDecl(ConstructorDecl* decl)
27512759
{
2752-
for (auto& paramDecl : decl->GetParameters())
2753-
{
2754-
ensureDecl(paramDecl, DeclCheckState::CanUseTypeOfValueDecl);
2755-
}
2756-
27572760
// We need to compute the result tyep for this declaration,
27582761
// since it wasn't filled in for us.
27592762
decl->ReturnType.type = findResultTypeForConstructorDecl(decl);
2763+
2764+
checkCallableDeclCommon(decl);
27602765
}
27612766

27622767
void SemanticsDeclHeaderVisitor::visitSubscriptDecl(SubscriptDecl* decl)
27632768
{
2764-
for (auto& paramDecl : decl->GetParameters())
2765-
{
2766-
ensureDecl(paramDecl, DeclCheckState::CanUseTypeOfValueDecl);
2767-
}
2768-
27692769
decl->ReturnType = CheckUsableType(decl->ReturnType);
27702770

27712771
// If we have a subscript declaration with no accessor declarations,
@@ -2789,6 +2789,8 @@ namespace Slang
27892789
getterDecl->ParentDecl = decl;
27902790
decl->Members.add(getterDecl);
27912791
}
2792+
2793+
checkCallableDeclCommon(decl);
27922794
}
27932795

27942796
void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl)
@@ -2808,6 +2810,8 @@ namespace Slang
28082810
{
28092811
getSink()->diagnose(decl, Diagnostics::accessorMustBeInsideSubscriptOrProperty);
28102812
}
2813+
2814+
checkCallableDeclCommon(decl);
28112815
}
28122816

28132817
GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl)

source/slang/slang-check-impl.h

+6-7
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ namespace Slang
408408
// so that we can add some quality-of-life features for users
409409
// in cases where the compiler crashes
410410
//
411-
void dispatchStmt(Stmt* stmt, FuncDecl* parentFunc, OuterStmtInfo* outerStmts);
411+
void dispatchStmt(Stmt* stmt, FunctionDeclBase* parentFunc, OuterStmtInfo* outerStmts);
412412
void dispatchExpr(Expr* expr);
413413

414414
/// Ensure that a declaration has been checked up to some state
@@ -781,7 +781,7 @@ namespace Slang
781781
// as the tag type for an `enum`
782782
void validateEnumTagType(Type* type, SourceLoc const& loc);
783783

784-
void checkStmt(Stmt* stmt, FuncDecl* outerFunction, OuterStmtInfo* outerStmts);
784+
void checkStmt(Stmt* stmt, FunctionDeclBase* outerFunction, OuterStmtInfo* outerStmts);
785785

786786
void getGenericParams(
787787
GenericDecl* decl,
@@ -1369,20 +1369,19 @@ namespace Slang
13691369
: public SemanticsVisitor
13701370
, StmtVisitor<SemanticsStmtVisitor>
13711371
{
1372-
SemanticsStmtVisitor(SharedSemanticsContext* shared, FuncDecl* parentFunc, OuterStmtInfo* outerStmts)
1372+
SemanticsStmtVisitor(SharedSemanticsContext* shared, FunctionDeclBase* parentFunc, OuterStmtInfo* outerStmts)
13731373
: SemanticsVisitor(shared)
13741374
, m_parentFunc(parentFunc)
13751375
, m_outerStmts(outerStmts)
13761376
{}
13771377

13781378
/// The parent function (if any) that surrounds the statement being checked.
1379-
// TODO: This should probably be a more general case like `CallableDecl`
1380-
FuncDecl* m_parentFunc = nullptr;
1379+
FunctionDeclBase* m_parentFunc = nullptr;
13811380

13821381
/// The linked list of lexically surrounding statements.
13831382
OuterStmtInfo* m_outerStmts = nullptr;
13841383

1385-
FuncDecl* getParentFunc() { return m_parentFunc; }
1384+
FunctionDeclBase* getParentFunc() { return m_parentFunc; }
13861385

13871386
void checkStmt(Stmt* stmt);
13881387

@@ -1433,7 +1432,7 @@ namespace Slang
14331432
: SemanticsVisitor(shared)
14341433
{}
14351434

1436-
void checkBodyStmt(Stmt* stmt, FuncDecl* parentDecl)
1435+
void checkBodyStmt(Stmt* stmt, FunctionDeclBase* parentDecl)
14371436
{
14381437
checkStmt(stmt, parentDecl, nullptr);
14391438
}

source/slang/slang-check-stmt.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace Slang
3535
};
3636
}
3737

38-
void SemanticsVisitor::checkStmt(Stmt* stmt, FuncDecl* parentDecl, OuterStmtInfo* outerStmts)
38+
void SemanticsVisitor::checkStmt(Stmt* stmt, FunctionDeclBase* parentDecl, OuterStmtInfo* outerStmts)
3939
{
4040
if (!stmt) return;
4141
dispatchStmt(stmt, parentDecl, outerStmts);

source/slang/slang-check.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ namespace Slang
223223
translationUnit->getModule()->_collectShaderParams();
224224
}
225225

226-
void SemanticsVisitor::dispatchStmt(Stmt* stmt, FuncDecl* parentFunc, OuterStmtInfo* outerStmts)
226+
void SemanticsVisitor::dispatchStmt(Stmt* stmt, FunctionDeclBase* parentFunc, OuterStmtInfo* outerStmts)
227227
{
228228
SemanticsStmtVisitor visitor(getShared(), parentFunc, outerStmts);
229229
try

source/slang/slang-lookup.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,11 @@ void DoLookupImpl(
518518
session,
519519
name, containerDeclRef, request, result, breadcrumbs);
520520

521-
if( auto funcDeclRef = containerDeclRef.as<FunctionDeclBase>() )
521+
if( containerDeclRef.is<ConstructorDecl>() )
522+
{
523+
thisParameterMode = LookupResultItem::Breadcrumb::ThisParameterMode::Mutating;
524+
}
525+
else if( auto funcDeclRef = containerDeclRef.as<FunctionDeclBase>() )
522526
{
523527
if( funcDeclRef.getDecl()->HasModifier<MutatingAttribute>() )
524528
{

source/slang/slang-lower-to-ir.cpp

+58-10
Original file line numberDiff line numberDiff line change
@@ -804,17 +804,34 @@ LoweredValInfo emitCallToDeclRef(
804804

805805
if( auto ctorDeclRef = funcDeclRef.as<ConstructorDecl>() )
806806
{
807-
// HACK: we know all constructors are builtins for now,
808-
// so we need to emit them as a call to the corresponding
809-
// builtin operation.
810-
//
811-
// TODO: these should all either be intrinsic operations,
812-
// or calls to library functions.
813-
814-
return LoweredValInfo::simple(builder->emitConstructorInst(type, argCount, args));
807+
if(!ctorDeclRef.getDecl()->Body)
808+
{
809+
// HACK: For legacy reasons, all of the built-in initializers
810+
// in the standard library are declared without proper
811+
// intrinsic-op modifiers, so we will assume that an
812+
// initializer without a body should map to `kIROp_Construct`.
813+
//
814+
// TODO: We should make all the initializers in the
815+
// standard library have either a body or a proper
816+
// intrinsic-op modifier.
817+
//
818+
// TODO: We should eliminate `kIROp_Construct` from the
819+
// IR completely, in favor of more detailed/specific ops
820+
// that cover the cases we actually care about.
821+
//
822+
return LoweredValInfo::simple(builder->emitConstructorInst(type, argCount, args));
823+
}
815824
}
816825

817826
// Fallback case is to emit an actual call.
827+
//
828+
// TODO: We are constructing a type that we expect the function
829+
// being called to have here, but that type doesn't account
830+
// for `in` vs. `out`/`inout` parameters, so it could easily
831+
// be wrong. We should sort out why this path in the code
832+
// even needs to be computing a type (rather than taking
833+
// it directly from the declaration).
834+
//
818835
if(!funcType)
819836
{
820837
List<IRType*> argTypes;
@@ -6175,15 +6192,46 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
61756192
}
61766193
}
61776194

6178-
// Lower body
61796195

6196+
// We will now set about emitting the code for the body of
6197+
// the function/callable.
6198+
//
6199+
// In the case of an initializer ("constructor") declaration,
6200+
// the `this` value is not a parameter, but rather a placeholder
6201+
// for the value that will be returned. We thus need to set up
6202+
// a local variable to represent this value.
6203+
//
6204+
auto constructorDecl = as<ConstructorDecl>(decl);
6205+
if(constructorDecl)
6206+
{
6207+
auto thisVar = subContext->irBuilder->emitVar(irResultType);
6208+
subContext->thisVal = LoweredValInfo::ptr(thisVar);
6209+
}
6210+
6211+
// We lower whatever statement was stored on the declaration
6212+
// as the body of the new IR function.
6213+
//
61806214
lowerStmt(subContext, decl->Body);
61816215

61826216
// We need to carefully add a terminator instruction to the end
61836217
// of the body, in case the user didn't do so.
6218+
//
61846219
if (!subContext->irBuilder->getBlock()->getTerminator())
61856220
{
6186-
if(as<IRVoidType>(irResultType))
6221+
if(constructorDecl)
6222+
{
6223+
// A constructor declaration should return the
6224+
// value of the `this` variable that was set
6225+
// up at the start.
6226+
//
6227+
// TODO: This should also apply if any code
6228+
// path in an initializer/constructor attempts
6229+
// to do an early `return;`.
6230+
//
6231+
subContext->irBuilder->emitReturn(
6232+
getSimpleVal(subContext, subContext->thisVal));
6233+
}
6234+
else if(as<IRVoidType>(irResultType))
61876235
{
61886236
// `void`-returning function can get an implicit
61896237
// return on exit of the body statement.

source/slang/slang-parser.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -2447,6 +2447,7 @@ namespace Slang
24472447
{
24482448
RefPtr<ConstructorDecl> decl = new ConstructorDecl();
24492449
parser->FillPosition(decl.Ptr());
2450+
parser->PushScope(decl);
24502451

24512452
// TODO: we need to make sure that all initializers have
24522453
// the same name, but that this name doesn't conflict
@@ -2462,6 +2463,7 @@ namespace Slang
24622463

24632464
decl->Body = parseOptBody(parser);
24642465

2466+
parser->PopScope();
24652467
return decl;
24662468
}
24672469

@@ -2506,6 +2508,7 @@ namespace Slang
25062508
{
25072509
RefPtr<SubscriptDecl> decl = new SubscriptDecl();
25082510
parser->FillPosition(decl.Ptr());
2511+
parser->PushScope(decl);
25092512

25102513
// TODO: the use of this name here is a bit magical...
25112514
decl->nameAndLoc.name = getName(parser, "operator[]");
@@ -2533,6 +2536,7 @@ namespace Slang
25332536
// empty body should be treated like `{ get; }`
25342537
}
25352538

2539+
parser->PopScope();
25362540
return decl;
25372541
}
25382542

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// user-defined-initializer.slang
2+
3+
// Confirm that user-defined initializer/constructor
4+
// methods in a type work as expected.
5+
6+
//TEST(compute):COMPARE_COMPUTE:
7+
//TEST(compute):COMPARE_COMPUTE:-cpu
8+
9+
struct Pair
10+
{
11+
int head;
12+
int tail;
13+
14+
__init(int h, int t)
15+
{
16+
head = h;
17+
tail = t;
18+
}
19+
20+
int getHead() { return head; }
21+
int getTail() { return tail; }
22+
}
23+
24+
int test(int value)
25+
{
26+
Pair p = Pair(value, value+1);
27+
return p.getHead()*16 + p.getTail();
28+
}
29+
30+
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
31+
RWStructuredBuffer<int> outputBuffer : register(u0);
32+
33+
[numthreads(4, 1, 1)]
34+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
35+
{
36+
uint tid = dispatchThreadID.x;
37+
int inVal = outputBuffer[tid];
38+
int outVal = test(inVal);
39+
outputBuffer[tid] = outVal;
40+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1
2+
12
3+
23
4+
34

0 commit comments

Comments
 (0)