Skip to content

Commit ac38a7a

Browse files
kaizhangNVdjohansson
authored andcommitted
Implement if(let ...) syntax (shader-slang#3673) (shader-slang#3958)
1 parent 531033f commit ac38a7a

File tree

6 files changed

+388
-1
lines changed

6 files changed

+388
-1
lines changed

docs/user-guide/06-interfaces-generics.md

+58
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,64 @@ T compute<T>(T a1, T a2)
677677
// compute(3, 1) == 2
678678
```
679679

680+
`as` operator can also be used in the `if` predicate to test if an object can be casted to a specific type, once the cast test is successful,
681+
the object can be used in the `if` block as the casted type without the need to retrieve the `Optional<T>::value` property:
682+
```csharp
683+
interface IFoo
684+
{
685+
void foo();
686+
}
687+
688+
struct MyImpl1 : IFoo
689+
{
690+
void foo() { printf("MyImpl1");}
691+
}
692+
693+
struct MyImpl2 : IFoo
694+
{
695+
void foo() { printf("MyImpl2");}
696+
}
697+
698+
struct MyImpl3 : IFoo
699+
{
700+
void foo() { printf("MyImpl3");}
701+
}
702+
703+
void test(IFoo foo)
704+
{
705+
// This syntax will be desugared to the following:
706+
// {
707+
// Optional<MyImpl1> $OptVar = foo as MyImpl1;
708+
// if ($OptVar.hasValue)
709+
// {
710+
// MyImpl1 t = $OptVar.value;
711+
// t.foo();
712+
// }
713+
// else if ...
714+
// }
715+
if (let t = foo as MyImpl1) // t is of type MyImpl1
716+
{
717+
t.foo();
718+
}
719+
else if (let t = foo as MyImpl2) // t is of type MyImpl2
720+
{
721+
t.foo();
722+
}
723+
else
724+
printf("fail");
725+
}
726+
727+
void main()
728+
{
729+
MyImpl1 v1;
730+
test(v1);
731+
732+
MyImpl2 v2;
733+
test(v2);
734+
}
735+
736+
```
737+
680738
Extensions to Interfaces
681739
-----------------------------
682740

source/slang/slang-parser.cpp

+108-1
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ namespace Slang
199199
Stmt* parseLabelStatement();
200200
DeclStmt* parseVarDeclrStatement(Modifiers modifiers);
201201
IfStmt* parseIfStatement();
202+
Stmt* parseIfLetStatement();
202203
ForStmt* ParseForStatement();
203204
WhileStmt* ParseWhileStatement();
204205
DoWhileStmt* ParseDoWhileStatement();
@@ -5276,7 +5277,16 @@ namespace Slang
52765277
if (LookAheadToken(TokenType::LBrace))
52775278
statement = parseBlockStatement();
52785279
else if (LookAheadToken("if"))
5279-
statement = parseIfStatement();
5280+
{
5281+
if(LookAheadToken("let", 2))
5282+
{
5283+
statement = parseIfLetStatement();
5284+
}
5285+
else
5286+
{
5287+
statement = parseIfStatement();
5288+
}
5289+
}
52805290
else if (LookAheadToken("for"))
52815291
statement = ParseForStatement();
52825292
else if (LookAheadToken("while"))
@@ -5579,6 +5589,103 @@ namespace Slang
55795589
return varDeclrStatement;
55805590
}
55815591

5592+
static Expr* constructIfLetPredicate(Parser* parser, VarExpr* varExpr)
5593+
{
5594+
// create a "var.hasValue" expression
5595+
MemberExpr* memberExpr = parser->astBuilder->create<MemberExpr>();
5596+
memberExpr->baseExpression = varExpr;
5597+
parser->FillPosition(memberExpr);
5598+
memberExpr->name = getName(parser, "hasValue");
5599+
5600+
return memberExpr;
5601+
}
5602+
5603+
// Parse the syntax 'if (let var = X as Y)'
5604+
Stmt* Parser::parseIfLetStatement()
5605+
{
5606+
ScopeDecl* scopeDecl = astBuilder->create<ScopeDecl>();
5607+
pushScopeAndSetParent(scopeDecl);
5608+
5609+
SeqStmt* newBody = astBuilder->create<SeqStmt>();
5610+
5611+
IfStmt* ifStatement = astBuilder->create<IfStmt>();
5612+
FillPosition(ifStatement);
5613+
ReadToken("if");
5614+
ReadToken(TokenType::LParent);
5615+
5616+
// parse 'let var = X as Y'
5617+
ReadToken("let");
5618+
auto identifierToken = ReadToken(TokenType::Identifier);
5619+
ReadToken(TokenType::OpAssign);
5620+
auto initExpr = ParseInitExpr();
5621+
5622+
// insert 'let tempVarDecl = X as Y;'
5623+
auto tempVarDecl = astBuilder->create<LetDecl>();
5624+
tempVarDecl->nameAndLoc = NameLoc(getName(this, "$OptVar"), identifierToken.loc);
5625+
tempVarDecl->initExpr = initExpr;
5626+
AddMember(currentScope->containerDecl, tempVarDecl);
5627+
5628+
DeclStmt* tmpVarDeclStmt = astBuilder->create<DeclStmt>();
5629+
FillPosition(tmpVarDeclStmt);
5630+
tmpVarDeclStmt->decl = tempVarDecl;
5631+
newBody->stmts.add(tmpVarDeclStmt);
5632+
5633+
// construct 'if (tempVarDecl.hasValue == true)'
5634+
VarExpr* tempVarExpr = astBuilder->create<VarExpr>();
5635+
tempVarExpr->scope = currentScope;
5636+
FillPosition(tempVarExpr);
5637+
tempVarExpr->name = tempVarDecl->getName();
5638+
ifStatement->predicate = constructIfLetPredicate(this, tempVarExpr);
5639+
5640+
ReadToken(TokenType::RParent);
5641+
5642+
// Create a new scope surrounding the positive statement, will be used for
5643+
// the variable declared in the if_let syntax
5644+
ScopeDecl* positiveScopeDecl = astBuilder->create<ScopeDecl>();
5645+
pushScopeAndSetParent(positiveScopeDecl);
5646+
ifStatement->positiveStatement = ParseStatement(ifStatement);
5647+
PopScope();
5648+
5649+
if (LookAheadToken("else"))
5650+
{
5651+
ReadToken("else");
5652+
ifStatement->negativeStatement = ParseStatement(ifStatement);
5653+
}
5654+
5655+
if (ifStatement->positiveStatement)
5656+
{
5657+
auto seqPositiveStmt = as<SeqStmt>(ifStatement->positiveStatement);
5658+
if (!seqPositiveStmt)
5659+
{
5660+
seqPositiveStmt = astBuilder->create<SeqStmt>();
5661+
}
5662+
5663+
MemberExpr* memberExpr = astBuilder->create<MemberExpr>();
5664+
memberExpr->baseExpression = tempVarExpr;
5665+
memberExpr->name = getName(this, "value");
5666+
5667+
auto varDecl = astBuilder->create<LetDecl>();
5668+
varDecl->nameAndLoc = NameLoc(identifierToken.getName(), identifierToken.loc);
5669+
varDecl->initExpr = memberExpr;
5670+
5671+
DeclStmt* varDeclrStatement = astBuilder->create<DeclStmt>();
5672+
varDeclrStatement->decl = varDecl;
5673+
5674+
// Add scope to the variable declared in the if_let syntax such
5675+
// that this variable cannot be used outside the positive statement
5676+
AddMember(positiveScopeDecl, varDecl);
5677+
5678+
seqPositiveStmt->stmts.add(varDeclrStatement);
5679+
seqPositiveStmt->stmts.add(ifStatement->positiveStatement);
5680+
ifStatement->positiveStatement = seqPositiveStmt;
5681+
}
5682+
5683+
newBody->stmts.add(ifStatement);
5684+
PopScope();
5685+
5686+
return newBody;
5687+
}
5688+
55825689
IfStmt* Parser::parseIfStatement()
55835690
{
55845691
IfStmt* ifStatement = astBuilder->create<IfStmt>();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -compute -shaderobj
2+
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cuda -compute -shaderobj
3+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj
4+
5+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
6+
RWStructuredBuffer<int> outputBuffer;
7+
8+
9+
interface IFoo
10+
{
11+
int foo(int a);
12+
}
13+
14+
struct MyImpl1 : IFoo
15+
{
16+
int foo(int a) { return a; }
17+
}
18+
19+
struct MyImpl2 : IFoo
20+
{
21+
int foo(int a) { return a + 5; }
22+
}
23+
24+
int test(IFoo foo, int idx)
25+
{
26+
int val = 0;
27+
if (let a = foo as MyImpl1)
28+
{
29+
val = a.foo(idx);
30+
}
31+
else if (let a = foo as MyImpl2)
32+
{
33+
val = a.foo(idx);
34+
}
35+
return (val);
36+
}
37+
38+
int test1<T>(T t)
39+
{
40+
if (let a = t as uint)
41+
{
42+
return 1;
43+
}
44+
else if(let a = t as float)
45+
{
46+
return 2;
47+
}
48+
else if (let a = t as double)
49+
{
50+
return 3;
51+
}
52+
else if (let a = t as int)
53+
{
54+
return 4;
55+
}
56+
else if (let a = t as uint64_t)
57+
{
58+
return 5;
59+
}
60+
else
61+
{
62+
return 6;
63+
}
64+
}
65+
66+
67+
[numthreads(1, 1, 1)]
68+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
69+
{
70+
MyImpl1 impl1;
71+
MyImpl2 impl2;
72+
// CHECK: 1
73+
// CHECK: 7
74+
outputBuffer[0] = test(impl1, 1);
75+
outputBuffer[1] = test(impl2, 2);
76+
77+
// CHECK: 1
78+
outputBuffer[2] = test1(2U);
79+
// CHECK: 2
80+
outputBuffer[3] = test1(2.0f);
81+
// CHECK: 3
82+
outputBuffer[4] = test1(2.0lf);
83+
// CHECK: 4
84+
outputBuffer[5] = test1(2);
85+
// CHECK: 5
86+
outputBuffer[6] = test1(2LLU);
87+
// CHECK: 6
88+
outputBuffer[7] = test1(impl1);
89+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target glsl -stage compute -entry computeMain
2+
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain
3+
//TEST:SIMPLE(filecheck=CHECK): -target cuda -stage compute -entry computeMain
4+
//TEST:SIMPLE(filecheck=CHECK): -target cpp -stage compute -entry computeMain
5+
6+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
7+
RWStructuredBuffer<int> outputBuffer;
8+
9+
10+
interface IFoo
11+
{
12+
int foo(int a);
13+
}
14+
15+
struct MyImpl : IFoo
16+
{
17+
int foo(int a) { return a; }
18+
}
19+
20+
struct MyImpl1 : IFoo
21+
{
22+
int foo(int a) { return a; }
23+
}
24+
25+
int test(IFoo foo, int idx)
26+
{
27+
int val = 0;
28+
if (let a = foo as MyImpl)
29+
{
30+
val = a.foo(idx);
31+
}
32+
// CHECK: error 30015: undefined identifier 'a'.
33+
else if(a == none)
34+
{
35+
val = -1;
36+
}
37+
else
38+
{
39+
// CHECK: error 30015: undefined identifier 'a'.
40+
if (a == none)
41+
{
42+
val = -1;
43+
}
44+
}
45+
return (val);
46+
}
47+
48+
[numthreads(4, 1, 1)]
49+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
50+
{
51+
MyImpl1 impl;
52+
outputBuffer[dispatchThreadID.x] = test(impl, dispatchThreadID.x);
53+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target glsl -stage compute -entry computeMain
2+
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage compute -entry computeMain
3+
//TEST:SIMPLE(filecheck=CHECK): -target cuda -stage compute -entry computeMain
4+
//TEST:SIMPLE(filecheck=CHECK): -target cpp -stage compute -entry computeMain
5+
6+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
7+
RWStructuredBuffer<int> outputBuffer;
8+
9+
10+
interface IFoo
11+
{
12+
int foo(int a);
13+
}
14+
15+
struct MyImpl : IFoo
16+
{
17+
int foo(int a) { return a; }
18+
}
19+
20+
struct MyImpl1 : IFoo
21+
{
22+
int foo(int a) { return a; }
23+
}
24+
25+
int test(IFoo foo, int idx)
26+
{
27+
int val = 0;
28+
// CHECK: error 20002: syntax error.
29+
if ((let a = foo as MyImpl))
30+
{
31+
val = a.foo(idx);
32+
}
33+
return (val);
34+
}
35+
36+
37+
[numthreads(4, 1, 1)]
38+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
39+
{
40+
MyImpl impl;
41+
outputBuffer[dispatchThreadID.x] = test(impl, dispatchThreadID.x);
42+
}

0 commit comments

Comments
 (0)