Skip to content

Commit e370fe2

Browse files
authored
Merge pull request shader-slang#335 from csyonghe/master
Support nested generic types (e.g. L<T<S>>)
2 parents 6924239 + d55b56b commit e370fe2

16 files changed

+302
-59
lines changed

source/slang/check.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -7003,8 +7003,18 @@ namespace Slang
70037003
}
70047004
}
70057005

7006-
// TODO: need to fill in constraints here...
7007-
7006+
// create default substitution arguments for constraints
7007+
for (auto mm : genericDecl->Members)
7008+
{
7009+
if (auto genericTypeConstraintDecl = mm.As<GenericTypeConstraintDecl>())
7010+
{
7011+
RefPtr<DeclaredSubtypeWitness> witness = new DeclaredSubtypeWitness();
7012+
witness->declRef = makeDeclRef(genericTypeConstraintDecl.Ptr());
7013+
witness->sub = genericTypeConstraintDecl->sub.type;
7014+
witness->sup = genericTypeConstraintDecl->sup.type;
7015+
subst->args.Add(witness);
7016+
}
7017+
}
70087018
return subst;
70097019
}
70107020
return parentSubst;

source/slang/emit.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -4503,8 +4503,12 @@ emitDeclImpl(decl, nullptr);
45034503
return name;
45044504
}
45054505

4506-
// Special case (2): not implemented yet.
4507-
4506+
// Special case (2)
4507+
if (declRef.GetParent().decl->As<AggTypeDecl>())
4508+
{
4509+
name.append(declRef.decl->nameAndLoc.name->text);
4510+
return name;
4511+
}
45084512
// General case:
45094513
name.append(getMangledName(declRef));
45104514
return name;

source/slang/ir-insts.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,8 @@ struct IRWitnessTableEntry : IRUser
304304
// to the IR values that satisfy those requirements.
305305
struct IRWitnessTable : IRGlobalValue
306306
{
307+
RefPtr<GenericDecl> genericDecl;
308+
DeclRef<Decl> subTypeDeclRef, supTypeDeclRef;
307309
IRValueList<IRWitnessTableEntry> entries;
308310
};
309311

@@ -341,6 +343,7 @@ struct SharedIRBuilder
341343

342344
Dictionary<IRInstKey, IRInst*> globalValueNumberingMap;
343345
Dictionary<IRConstantKey, IRConstant*> constantMap;
346+
Dictionary<String, IRWitnessTable*> witnessTableMap;
344347
};
345348

346349
struct IRBuilderSourceLocRAII;
@@ -417,7 +420,7 @@ struct IRBuilder
417420
IRValue* const* args);
418421

419422
IRModule* createModule();
420-
423+
421424
IRFunc* createFunc();
422425
IRGlobalVar* createGlobalVar(
423426
IRType* valueType);
@@ -427,7 +430,8 @@ struct IRBuilder
427430
IRWitnessTable* witnessTable,
428431
IRValue* requirementKey,
429432
IRValue* satisfyingVal);
430-
433+
IRWitnessTable* lookupWitnessTable(String mangledName);
434+
void registerWitnessTable(IRWitnessTable* table);
431435
IRBlock* createBlock();
432436
IRBlock* emitBlock();
433437

source/slang/ir.cpp

+139-38
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ namespace Slang
6060
// clear out the old value
6161
if (usedVal)
6262
{
63-
*prevLink = nextUse;
63+
if (prevLink)
64+
*prevLink = nextUse;
6465
}
6566

6667
init(user, usedVal);
@@ -934,6 +935,19 @@ namespace Slang
934935
return entry;
935936
}
936937

938+
IRWitnessTable * IRBuilder::lookupWitnessTable(String mangledName)
939+
{
940+
IRWitnessTable * result;
941+
if (sharedBuilder->witnessTableMap.TryGetValue(mangledName, result))
942+
return result;
943+
return nullptr;
944+
}
945+
946+
947+
void IRBuilder::registerWitnessTable(IRWitnessTable * table)
948+
{
949+
sharedBuilder->witnessTableMap[table->mangledName] = table;
950+
}
937951

938952
IRBlock* IRBuilder::createBlock()
939953
{
@@ -1396,7 +1410,7 @@ namespace Slang
13961410
struct IRDumpContext
13971411
{
13981412
StringBuilder* builder;
1399-
int indent;
1413+
int indent = 0;
14001414

14011415
UInt idCounter = 1;
14021416
Dictionary<IRValue*, UInt> mapValueToID;
@@ -1588,7 +1602,7 @@ namespace Slang
15881602
}
15891603
else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val))
15901604
{
1591-
dumpOperand(context, proxyVal->inst);
1605+
dumpOperand(context, proxyVal->inst.usedValue);
15921606
}
15931607
else
15941608
{
@@ -2095,6 +2109,17 @@ namespace Slang
20952109
dump(context, "}\n");
20962110
}
20972111

2112+
2113+
String dumpIRFunc(IRFunc* func)
2114+
{
2115+
IRDumpContext dumpContext;
2116+
StringBuilder sbDump;
2117+
dumpContext.builder = &sbDump;
2118+
dumpIRFunc(&dumpContext, func);
2119+
auto strFunc = sbDump.ToString();
2120+
return strFunc;
2121+
}
2122+
20982123
void dumpIRGlobalVar(
20992124
IRDumpContext* context,
21002125
IRGlobalVar* var)
@@ -3199,7 +3224,6 @@ namespace Slang
31993224
case kIROp_Func:
32003225
case kIROp_witness_table:
32013226
return cloneGlobalValue(this, (IRGlobalValue*) originalValue);
3202-
32033227
case kIROp_boolConst:
32043228
{
32053229
IRConstant* c = (IRConstant*)originalValue;
@@ -3246,7 +3270,7 @@ namespace Slang
32463270
{
32473271
auto proxyVal = witness.Value.As<IRProxyVal>();
32483272
SLANG_ASSERT(proxyVal);
3249-
return proxyVal->inst;
3273+
return proxyVal->inst.usedValue;
32503274
}
32513275
}
32523276
}
@@ -3270,16 +3294,20 @@ namespace Slang
32703294
}
32713295
}
32723296

3297+
IRValue* cloneValue(
3298+
IRSpecContextBase* context,
3299+
IRValue* originalValue);
3300+
32733301
RefPtr<Val> cloneSubstitutionArg(
32743302
IRSpecContext* context,
32753303
Val* val)
32763304
{
32773305
if (auto proxyVal = dynamic_cast<IRProxyVal*>(val))
32783306
{
3279-
auto newIRVal = context->maybeCloneValue(proxyVal->inst);
3307+
auto newIRVal = cloneValue(context, proxyVal->inst.usedValue);
32803308

32813309
RefPtr<IRProxyVal> newProxyVal = new IRProxyVal();
3282-
newProxyVal->inst = newIRVal;
3310+
newProxyVal->inst.init(nullptr, newIRVal);
32833311
return newProxyVal;
32843312
}
32853313
else if (auto type = dynamic_cast<Type*>(val))
@@ -3307,7 +3335,7 @@ namespace Slang
33073335
for (auto arg : genSubst->args)
33083336
{
33093337
auto newArg = cloneSubstitutionArg(context, arg);
3310-
newSubst->args.Add(arg);
3338+
newSubst->args.Add(newArg);
33113339
}
33123340
return newSubst;
33133341
}
@@ -3436,7 +3464,7 @@ namespace Slang
34363464
}
34373465

34383466
IRWitnessTable* cloneWitnessTableImpl(
3439-
IRSpecContext* context,
3467+
IRSpecContextBase* context,
34403468
IRWitnessTable* originalTable,
34413469
IROriginalValuesForClone const& originalValues)
34423470
{
@@ -3445,7 +3473,9 @@ namespace Slang
34453473

34463474
auto mangledName = originalTable->mangledName;
34473475
clonedTable->mangledName = mangledName;
3448-
3476+
clonedTable->genericDecl = originalTable->genericDecl;
3477+
clonedTable->subTypeDeclRef = originalTable->subTypeDeclRef;
3478+
clonedTable->supTypeDeclRef = originalTable->supTypeDeclRef;
34493479
cloneDecorations(context, clonedTable, originalTable);
34503480

34513481
// Clone the entries in the witness table as well
@@ -3463,7 +3493,7 @@ namespace Slang
34633493
}
34643494

34653495
IRWitnessTable* cloneWitnessTableWithoutRegistering(
3466-
IRSpecContext* context,
3496+
IRSpecContextBase* context,
34673497
IRWitnessTable* originalTable)
34683498
{
34693499
return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone());
@@ -4008,7 +4038,7 @@ namespace Slang
40084038
SLANG_ASSERT(table);
40094039
table = cloneWitnessTableWithoutRegistering(context, (IRWitnessTable*)(table));
40104040
IRProxyVal * tableVal = new IRProxyVal();
4011-
tableVal->inst = table;
4041+
tableVal->inst.init(nullptr, table);
40124042
paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal));
40134043
}
40144044
}
@@ -4228,7 +4258,7 @@ namespace Slang
42284258
// the pointed-to value and not the proxy type-level `Val`
42294259
// instead.
42304260

4231-
return context->maybeCloneValue(proxyVal->inst);
4261+
return context->maybeCloneValue(proxyVal->inst.usedValue);
42324262
}
42334263
else
42344264
{
@@ -4394,6 +4424,69 @@ namespace Slang
43944424
return newSubst;
43954425
}
43964426

4427+
IRFunc* getSpecializedFunc(
4428+
IRSharedGenericSpecContext* sharedContext,
4429+
IRFunc* genericFunc,
4430+
DeclRef<Decl> specDeclRef);
4431+
4432+
IRWitnessTable* specializeWitnessTable(IRSharedGenericSpecContext * sharedContext, IRWitnessTable* originalTable, DeclRef<Decl> specDeclRef)
4433+
{
4434+
// First, we want to see if an existing specialization
4435+
// has already been made. To do that we will need to
4436+
// compute the mangled name of the specialized function,
4437+
// so that we can look for existing declarations.
4438+
String specMangledName;
4439+
String specializedMangledName = getMangledNameForConformanceWitness(specDeclRef.Substitute(originalTable->subTypeDeclRef),
4440+
specDeclRef.Substitute(originalTable->supTypeDeclRef));
4441+
4442+
// TODO: This is a terrible linear search, and we should
4443+
// avoid it by building a dictionary ahead of time,
4444+
// as is being done for the `IRSpecContext` used above.
4445+
// We can probalby use the same basic context, actually.
4446+
auto module = originalTable->parentModule;
4447+
for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
4448+
{
4449+
if (gv->mangledName == specMangledName)
4450+
return (IRWitnessTable*)gv;
4451+
}
4452+
4453+
RefPtr<Substitutions> newSubst = cloneSubstitutionsForSpecialization(
4454+
sharedContext,
4455+
specDeclRef.substitutions,
4456+
originalTable->genericDecl);
4457+
4458+
IRGenericSpecContext context;
4459+
context.shared = sharedContext;
4460+
context.builder = &sharedContext->builderStorage;
4461+
context.subst = newSubst;
4462+
4463+
// TODO: other initialization is needed here...
4464+
4465+
auto specTable = cloneWitnessTableWithoutRegistering(&context, originalTable);
4466+
4467+
// Set up the clone to recognize that it is no longer generic
4468+
specTable->mangledName = specMangledName;
4469+
specTable->genericDecl = nullptr;
4470+
4471+
// Specialization of witness tables should trigger cascading specializations
4472+
// of involved functions.
4473+
for (auto entry : specTable->entries)
4474+
{
4475+
if (entry->satisfyingVal.usedValue->op == kIROp_Func)
4476+
{
4477+
IRFunc* func = (IRFunc*)entry->satisfyingVal.usedValue;
4478+
if (func->genericDecl)
4479+
entry->satisfyingVal.set(getSpecializedFunc(sharedContext, func, specDeclRef));
4480+
}
4481+
4482+
}
4483+
// We also need to make sure that we register this specialized
4484+
// function under its mangled name, so that later lookup
4485+
// steps will find it.
4486+
insertGlobalValueSymbol(sharedContext, specTable);
4487+
4488+
return specTable;
4489+
}
43974490

43984491
IRFunc* getSpecializedFunc(
43994492
IRSharedGenericSpecContext* sharedContext,
@@ -4415,9 +4508,9 @@ namespace Slang
44154508
// as is being done for the `IRSpecContext` used above.
44164509
// We can probalby use the same basic context, actually.
44174510
auto module = genericFunc->parentModule;
4418-
for(auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
4511+
for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
44194512
{
4420-
if(gv->mangledName == specMangledName)
4513+
if (gv->mangledName == specMangledName)
44214514
return (IRFunc*) gv;
44224515
}
44234516

@@ -4564,7 +4657,7 @@ namespace Slang
45644657
// specialization to perform.
45654658
auto func = workList[count-1];
45664659
workList.RemoveAt(count-1);
4567-
4660+
45684661
// We are going to go ahead and walk through
45694662
// all the instructions in this function,
45704663
// and look for `specialize` operations.
@@ -4596,32 +4689,40 @@ namespace Slang
45964689
// specialization here and now.
45974690
IRSpecialize* specInst = (IRSpecialize*) ii;
45984691

4599-
// We need to check that the value being specialized is
4600-
// a generic function.
4601-
auto genericVal = specInst->genericVal.usedValue;
4602-
if(genericVal->op != kIROp_Func)
4603-
continue;
4604-
auto genericFunc = (IRFunc*) genericVal;
4605-
if(!genericFunc->genericDecl)
4606-
continue;
4607-
46084692
// Now we extract the specialized decl-ref that will
46094693
// tell us how to specialize things.
4610-
auto specDeclRefVal = (IRDeclRef*) specInst->specDeclRefVal.usedValue;
4694+
auto specDeclRefVal = (IRDeclRef*)specInst->specDeclRefVal.usedValue;
46114695
auto specDeclRef = specDeclRefVal->declRef;
46124696

4613-
// Okay, we have a candidate for specialization here.
4614-
//
4615-
// We will first find or construct a specialized version
4616-
// of the callee funciton/
4617-
auto specFunc = getSpecializedFunc(sharedContext, genericFunc, specDeclRef);
4618-
//
4619-
// Then we will replace the use sites for the `specialize`
4620-
// instruction with uses of the specialized function.
4621-
//
4622-
specInst->replaceUsesWith(specFunc);
4623-
4624-
specInst->removeAndDeallocate();
4697+
// We need to specialize functions and witness tables
4698+
auto genericVal = specInst->genericVal.usedValue;
4699+
if (genericVal->op == kIROp_Func)
4700+
{
4701+
auto genericFunc = (IRFunc*)genericVal;
4702+
if (!genericFunc->genericDecl)
4703+
continue;
4704+
4705+
// Okay, we have a candidate for specialization here.
4706+
//
4707+
// We will first find or construct a specialized version
4708+
// of the callee funciton/
4709+
auto specFunc = getSpecializedFunc(sharedContext, genericFunc, specDeclRef);
4710+
//
4711+
// Then we will replace the use sites for the `specialize`
4712+
// instruction with uses of the specialized function.
4713+
//
4714+
specInst->replaceUsesWith(specFunc);
4715+
4716+
specInst->removeAndDeallocate();
4717+
}
4718+
else if (genericVal->op == kIROp_witness_table)
4719+
{
4720+
// specialize a witness table
4721+
auto originalTable = (IRWitnessTable*)genericVal;
4722+
auto specWitnessTable = specializeWitnessTable(sharedContext, originalTable, specDeclRef);
4723+
specInst->replaceUsesWith(specWitnessTable);
4724+
specInst->removeAndDeallocate();
4725+
}
46254726
}
46264727
break;
46274728

source/slang/ir.h

+1
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ void printSlangIRAssembly(StringBuilder& builder, IRModule* module);
471471
String getSlangIRAssembly(IRModule* module);
472472

473473
void dumpIR(IRModule* module);
474+
String dumpIRFunc(IRFunc* func);
474475

475476
}
476477

0 commit comments

Comments
 (0)