Skip to content

Commit 8e0e6b6

Browse files
csyongheTim Foley
authored and
Tim Foley
committed
Support simple generics syntax (shader-slang#319)
* Support simple generics syntax. This commit enables simpler generics syntax, e.g. T test<T>(T arg) {} or struct Gen<T>{T x;}; * Support simple generics syntax. This commit enables simpler generics syntax, e.g. T test<T>(T arg) {} or struct Gen<T>{T x;}; * add expected test result for compute/generics-syntax.slang
1 parent 563fc0c commit 8e0e6b6

File tree

4 files changed

+203
-113
lines changed

4 files changed

+203
-113
lines changed

source/slang/lower-to-ir.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -3446,6 +3446,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
34463446
{
34473447
if (auto innerFuncDecl = genDecl->inner->As<FuncDecl>())
34483448
return lowerFuncDecl(innerFuncDecl);
3449+
else if (auto innerStructDecl = genDecl->inner->As<StructDecl>())
3450+
return LoweredValInfo();
34493451
SLANG_RELEASE_ASSERT(false);
34503452
UNREACHABLE_RETURN(LoweredValInfo());
34513453
}

source/slang/parser.cpp

+168-113
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ namespace Slang
9292
bool LookAheadToken(const char * string, int offset = 0);
9393
void parseSourceFile(ModuleDecl* program);
9494
RefPtr<ModuleDecl> ParseProgram();
95-
RefPtr<StructDecl> ParseStruct();
95+
RefPtr<Decl> ParseStruct();
9696
RefPtr<ClassDecl> ParseClass();
9797
RefPtr<Stmt> ParseStatement();
9898
RefPtr<Stmt> parseBlockStatement();
@@ -976,6 +976,101 @@ namespace Slang
976976
}
977977
}
978978

979+
static RefPtr<Decl> ParseGenericParamDecl(
980+
Parser* parser,
981+
RefPtr<GenericDecl> genericDecl)
982+
{
983+
// simple syntax to introduce a value parameter
984+
if (AdvanceIf(parser, "let"))
985+
{
986+
// default case is a type parameter
987+
auto paramDecl = new GenericValueParamDecl();
988+
paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
989+
if (AdvanceIf(parser, TokenType::Colon))
990+
{
991+
paramDecl->type = parser->ParseTypeExp();
992+
}
993+
if (AdvanceIf(parser, TokenType::OpAssign))
994+
{
995+
paramDecl->initExpr = parser->ParseInitExpr();
996+
}
997+
return paramDecl;
998+
}
999+
else
1000+
{
1001+
// default case is a type parameter
1002+
auto paramDecl = new GenericTypeParamDecl();
1003+
parser->FillPosition(paramDecl);
1004+
paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
1005+
if (AdvanceIf(parser, TokenType::Colon))
1006+
{
1007+
// The user is apply a constraint to this type parameter...
1008+
1009+
auto paramConstraint = new GenericTypeConstraintDecl();
1010+
parser->FillPosition(paramConstraint);
1011+
1012+
auto paramType = DeclRefType::Create(
1013+
parser->getSession(),
1014+
DeclRef<Decl>(paramDecl, nullptr));
1015+
1016+
auto paramTypeExpr = new SharedTypeExpr();
1017+
paramTypeExpr->loc = paramDecl->loc;
1018+
paramTypeExpr->base.type = paramType;
1019+
paramTypeExpr->type = QualType(getTypeType(paramType));
1020+
1021+
paramConstraint->sub = TypeExp(paramTypeExpr);
1022+
paramConstraint->sup = parser->ParseTypeExp();
1023+
1024+
AddMember(genericDecl, paramConstraint);
1025+
1026+
1027+
}
1028+
if (AdvanceIf(parser, TokenType::OpAssign))
1029+
{
1030+
paramDecl->initType = parser->ParseTypeExp();
1031+
}
1032+
return paramDecl;
1033+
}
1034+
}
1035+
1036+
template<typename TFunc>
1037+
static void ParseGenericDeclImpl(
1038+
Parser* parser, GenericDecl* decl, const TFunc & parseInnerFunc)
1039+
{
1040+
parser->ReadToken(TokenType::OpLess);
1041+
parser->genericDepth++;
1042+
while (!parser->LookAheadToken(TokenType::OpGreater))
1043+
{
1044+
AddMember(decl, ParseGenericParamDecl(parser, decl));
1045+
1046+
if (parser->LookAheadToken(TokenType::OpGreater))
1047+
break;
1048+
1049+
parser->ReadToken(TokenType::Comma);
1050+
}
1051+
parser->genericDepth--;
1052+
parser->ReadToken(TokenType::OpGreater);
1053+
decl->inner = parseInnerFunc(decl);
1054+
decl->inner->ParentDecl = decl;
1055+
// A generic decl hijacks the name of the declaration
1056+
// it wraps, so that lookup can find it.
1057+
if (decl->inner)
1058+
{
1059+
decl->nameAndLoc = decl->inner->nameAndLoc;
1060+
decl->loc = decl->inner->loc;
1061+
}
1062+
}
1063+
1064+
static RefPtr<RefObject> ParseGenericDecl(Parser* parser, void*)
1065+
{
1066+
RefPtr<GenericDecl> decl = new GenericDecl();
1067+
parser->FillPosition(decl.Ptr());
1068+
parser->PushScope(decl.Ptr());
1069+
ParseGenericDeclImpl(parser, decl.Ptr(), [=](GenericDecl* genDecl) {return ParseSingleDecl(parser, genDecl); });
1070+
parser->PopScope();
1071+
return decl;
1072+
}
1073+
9791074
static void parseParameterList(
9801075
Parser* parser,
9811076
RefPtr<CallableDecl> decl)
@@ -1004,29 +1099,62 @@ namespace Slang
10041099
}
10051100
}
10061101

1007-
static void ParseFuncDeclHeader(
1102+
static RefPtr<Decl> ParseFuncDeclHeader(
10081103
Parser* parser,
10091104
DeclaratorInfo const& declaratorInfo,
1010-
RefPtr<FuncDecl> decl)
1105+
RefPtr<FuncDecl> decl,
1106+
RefPtr<GenericDecl> genDecl)
10111107
{
1012-
parser->PushScope(decl.Ptr());
1108+
RefPtr<Decl> retDecl = decl;
10131109

10141110
parser->FillPosition(decl.Ptr());
10151111
decl->loc = declaratorInfo.nameAndLoc.loc;
10161112

10171113
decl->nameAndLoc = declaratorInfo.nameAndLoc;
1114+
1115+
// if return type is a DeclRef type, we need to update its scope to use this function decl's scope
1116+
// so that LookUp can find the generic type parameters declared after the function name
1117+
if (auto declRefRetType = declaratorInfo.typeSpec.As<DeclRefExpr>())
1118+
declRefRetType->scope = parser->currentScope;
1119+
10181120
decl->ReturnType = TypeExp(declaratorInfo.typeSpec);
1019-
parseParameterList(parser, decl);
1020-
ParseOptSemantics(parser, decl.Ptr());
1121+
auto parseFuncDeclHeaderInner = [&](GenericDecl *)
1122+
{
1123+
parseParameterList(parser, decl);
1124+
ParseOptSemantics(parser, decl.Ptr());
1125+
return decl;
1126+
};
1127+
1128+
if (parser->LookAheadToken(TokenType::OpLess))
1129+
{
1130+
// parse generic parameters
1131+
ParseGenericDeclImpl(parser, genDecl.Ptr(), parseFuncDeclHeaderInner);
1132+
retDecl = genDecl;
1133+
}
1134+
else
1135+
parseFuncDeclHeaderInner(nullptr);
1136+
1137+
return retDecl;
10211138
}
10221139

10231140
static RefPtr<Decl> ParseFuncDecl(
10241141
Parser* parser,
10251142
ContainerDecl* /*containerDecl*/,
1026-
DeclaratorInfo const& declaratorInfo)
1143+
DeclaratorInfo const& declaratorInfo,
1144+
bool isGeneric)
10271145
{
10281146
RefPtr<FuncDecl> decl = new FuncDecl();
1029-
ParseFuncDeclHeader(parser, declaratorInfo, decl);
1147+
RefPtr<Decl> retDecl = decl;
1148+
RefPtr<GenericDecl> genDecl;
1149+
if (isGeneric)
1150+
{
1151+
genDecl = new GenericDecl();
1152+
parser->FillPosition(genDecl);
1153+
parser->PushScope(genDecl);
1154+
retDecl = genDecl;
1155+
}
1156+
parser->PushScope(decl.Ptr());
1157+
ParseFuncDeclHeader(parser, declaratorInfo, decl, genDecl);
10301158

10311159
if (AdvanceIf(parser, TokenType::Semicolon))
10321160
{
@@ -1038,7 +1166,11 @@ namespace Slang
10381166
}
10391167

10401168
parser->PopScope();
1041-
return decl;
1169+
if (isGeneric)
1170+
{
1171+
parser->PopScope();
1172+
}
1173+
return retDecl;
10421174
}
10431175

10441176
static RefPtr<VarDeclBase> CreateVarDeclForContext(
@@ -1594,7 +1726,8 @@ namespace Slang
15941726
// matter unless we actually decide to support function-type parameters,
15951727
// using C syntax.
15961728
//
1597-
if( parser->tokenReader.PeekTokenType() == TokenType::LParent
1729+
if ((parser->tokenReader.PeekTokenType() == TokenType::LParent ||
1730+
parser->tokenReader.PeekTokenType() == TokenType::OpLess)
15981731

15991732
// Only parse as a function if we didn't already see mutually-exclusive
16001733
// constructs when parsing the declarator.
@@ -1603,7 +1736,7 @@ namespace Slang
16031736
{
16041737
// Looks like a function, so parse it like one.
16051738
UnwrapDeclarator(initDeclarator, &declaratorInfo);
1606-
return ParseFuncDecl(parser, containerDecl, declaratorInfo);
1739+
return ParseFuncDecl(parser, containerDecl, declaratorInfo, parser->tokenReader.PeekTokenType() == TokenType::OpLess);
16071740
}
16081741

16091742
// Otherwise we are looking at a variable declaration, which could be one in a sequence...
@@ -2047,101 +2180,6 @@ namespace Slang
20472180
return blockVarDecl;
20482181
}
20492182

2050-
2051-
2052-
2053-
static RefPtr<Decl> ParseGenericParamDecl(
2054-
Parser* parser,
2055-
RefPtr<GenericDecl> genericDecl)
2056-
{
2057-
// simple syntax to introduce a value parameter
2058-
if (AdvanceIf(parser, "let"))
2059-
{
2060-
// default case is a type parameter
2061-
auto paramDecl = new GenericValueParamDecl();
2062-
paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
2063-
if (AdvanceIf(parser, TokenType::Colon))
2064-
{
2065-
paramDecl->type = parser->ParseTypeExp();
2066-
}
2067-
if (AdvanceIf(parser, TokenType::OpAssign))
2068-
{
2069-
paramDecl->initExpr = parser->ParseInitExpr();
2070-
}
2071-
return paramDecl;
2072-
}
2073-
else
2074-
{
2075-
// default case is a type parameter
2076-
auto paramDecl = new GenericTypeParamDecl();
2077-
parser->FillPosition(paramDecl);
2078-
paramDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
2079-
if (AdvanceIf(parser, TokenType::Colon))
2080-
{
2081-
// The user is apply a constraint to this type parameter...
2082-
2083-
auto paramConstraint = new GenericTypeConstraintDecl();
2084-
parser->FillPosition(paramConstraint);
2085-
2086-
auto paramType = DeclRefType::Create(
2087-
parser->getSession(),
2088-
DeclRef<Decl>(paramDecl, nullptr));
2089-
2090-
auto paramTypeExpr = new SharedTypeExpr();
2091-
paramTypeExpr->loc = paramDecl->loc;
2092-
paramTypeExpr->base.type = paramType;
2093-
paramTypeExpr->type = QualType(getTypeType(paramType));
2094-
2095-
paramConstraint->sub = TypeExp(paramTypeExpr);
2096-
paramConstraint->sup = parser->ParseTypeExp();
2097-
2098-
AddMember(genericDecl, paramConstraint);
2099-
2100-
2101-
}
2102-
if (AdvanceIf(parser, TokenType::OpAssign))
2103-
{
2104-
paramDecl->initType = parser->ParseTypeExp();
2105-
}
2106-
return paramDecl;
2107-
}
2108-
}
2109-
2110-
static RefPtr<RefObject> ParseGenericDecl(
2111-
Parser* parser, void* /*userData*/)
2112-
{
2113-
RefPtr<GenericDecl> decl = new GenericDecl();
2114-
parser->FillPosition(decl.Ptr());
2115-
parser->PushScope(decl.Ptr());
2116-
2117-
parser->ReadToken(TokenType::OpLess);
2118-
parser->genericDepth++;
2119-
while (!parser->LookAheadToken(TokenType::OpGreater))
2120-
{
2121-
AddMember(decl, ParseGenericParamDecl(parser, decl));
2122-
2123-
if( parser->LookAheadToken(TokenType::OpGreater) )
2124-
break;
2125-
2126-
parser->ReadToken(TokenType::Comma);
2127-
}
2128-
parser->genericDepth--;
2129-
parser->ReadToken(TokenType::OpGreater);
2130-
2131-
decl->inner = ParseSingleDecl(parser, decl.Ptr());
2132-
2133-
// A generic decl hijacks the name of the declaration
2134-
// it wraps, so that lookup can find it.
2135-
if (decl->inner)
2136-
{
2137-
decl->nameAndLoc = decl->inner->nameAndLoc;
2138-
decl->loc = decl->inner->loc;
2139-
}
2140-
2141-
parser->PopScope();
2142-
return decl;
2143-
}
2144-
21452183
static RefPtr<RefObject> ParseExtensionDecl(Parser* parser, void* /*userData*/)
21462184
{
21472185
RefPtr<ExtensionDecl> decl = new ExtensionDecl();
@@ -2640,22 +2678,39 @@ namespace Slang
26402678
return program;
26412679
}
26422680

2643-
RefPtr<StructDecl> Parser::ParseStruct()
2681+
RefPtr<Decl> Parser::ParseStruct()
26442682
{
26452683
RefPtr<StructDecl> rs = new StructDecl();
2684+
RefPtr<Decl> retDecl = rs;
26462685
FillPosition(rs.Ptr());
26472686
ReadToken("struct");
26482687

26492688
// TODO: support `struct` declaration without tag
26502689
rs->nameAndLoc = expectIdentifier(this);
26512690

2652-
// We allow for an inheritance clause on a `struct`
2653-
// so that it can conform to interfaces.
2654-
parseOptionalInheritanceClause(this, rs.Ptr());
2655-
2656-
parseAggTypeDeclBody(this, rs.Ptr());
2691+
auto parseStructInner = [&](GenericDecl*)
2692+
{
2693+
// We allow for an inheritance clause on a `struct`
2694+
// so that it can conform to interfaces.
2695+
parseOptionalInheritanceClause(this, rs.Ptr());
2696+
parseAggTypeDeclBody(this, rs.Ptr());
2697+
return rs;
2698+
};
26572699

2658-
return rs;
2700+
if (LookAheadToken(TokenType::OpLess))
2701+
{
2702+
RefPtr<GenericDecl> genDecl = new GenericDecl();
2703+
FillPosition(genDecl.Ptr());
2704+
PushScope(genDecl);
2705+
ParseGenericDeclImpl(this, genDecl.Ptr(), parseStructInner);
2706+
PopScope();
2707+
retDecl = genDecl;
2708+
}
2709+
else
2710+
{
2711+
parseStructInner(nullptr);
2712+
}
2713+
return retDecl;
26592714
}
26602715

26612716
RefPtr<ClassDecl> Parser::ParseClass()

0 commit comments

Comments
 (0)