Skip to content

Commit 8daafcc

Browse files
csyongheTim Foley
authored and
Tim Foley
committed
bruteforce implementation of witness table resolution for associated (shader-slang#358)
1 parent 3d435f7 commit 8daafcc

16 files changed

+421
-149
lines changed

source/slang/check.cpp

+109-112
Large diffs are not rendered by default.

source/slang/decl-defs.h

+22-7
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,30 @@ SIMPLE_SYNTAX_CLASS(ClassDecl, AggTypeDecl)
9090
// An interface which other types can conform to
9191
SIMPLE_SYNTAX_CLASS(InterfaceDecl, AggTypeDecl)
9292

93+
ABSTRACT_SYNTAX_CLASS(TypeConstraintDecl, Decl)
94+
RAW(
95+
virtual TypeExp& getSup() = 0;
96+
)
97+
END_SYNTAX_CLASS()
98+
9399
// A kind of pseudo-member that represents an explicit
94100
// or implicit inheritance relationship.
95101
//
96-
SYNTAX_CLASS(InheritanceDecl, Decl)
97-
// The type expression as written
102+
SYNTAX_CLASS(InheritanceDecl, TypeConstraintDecl)
103+
// The type expression as written
98104
SYNTAX_FIELD(TypeExp, base)
99105

100-
RAW(
106+
RAW(
101107
// After checking, this dictionary will map members
102108
// required by the base type to their concrete
103109
// implementations in the type that contains
104110
// this inheritance declaration.
105-
Dictionary<DeclRef<Decl>, Decl*> requirementWitnesses;
106-
)
107-
111+
Dictionary<DeclRef<Decl>, DeclRef<Decl>> requirementWitnesses;
112+
virtual TypeExp& getSup() override
113+
{
114+
return base;
115+
}
116+
)
108117
END_SYNTAX_CLASS()
109118

110119
// TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance
@@ -216,13 +225,19 @@ SYNTAX_CLASS(GenericTypeParamDecl, SimpleTypeDecl)
216225
END_SYNTAX_CLASS()
217226

218227
// A constraint placed as part of a generic declaration
219-
SYNTAX_CLASS(GenericTypeConstraintDecl, Decl)
228+
SYNTAX_CLASS(GenericTypeConstraintDecl, TypeConstraintDecl)
220229
// A type constraint like `T : U` is constraining `T` to be "below" `U`
221230
// on a lattice of types. This may not be a subtyping relationship
222231
// per se, but it makes sense to use that terminology here, so we
223232
// think of these fields as the sub-type and sup-ertype, respectively.
224233
SYNTAX_FIELD(TypeExp, sub)
225234
SYNTAX_FIELD(TypeExp, sup)
235+
RAW(
236+
virtual TypeExp& getSup() override
237+
{
238+
return sup;
239+
}
240+
)
226241
END_SYNTAX_CLASS()
227242

228243
SIMPLE_SYNTAX_CLASS(GenericValueParamDecl, VarDeclBase)

source/slang/ir-inst-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ INST(decl_ref, decl_ref, 0, 0)
9898

9999
INST(specialize, specialize, 2, 0)
100100
INST(lookup_interface_method, lookup_interface_method, 2, 0)
101+
INST(lookup_witness_table, lookup_witness_table, 2, 0)
101102

102103
INST(Construct, construct, 0, 0)
103104
INST(Call, call, 1, 0)

source/slang/ir-insts.h

+23-1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ struct IRLookupWitnessMethod : IRInst
8484
IRUse requirementDeclRef;
8585
};
8686

87+
struct IRLookupWitnessTable : IRInst
88+
{
89+
IRUse sourceType;
90+
IRUse interfaceType;
91+
};
92+
8793
//
8894

8995
struct IRCall : IRInst
@@ -309,6 +315,14 @@ struct IRWitnessTable : IRGlobalValue
309315
IRValueList<IRWitnessTableEntry> entries;
310316
};
311317

318+
// An abstract witness table is a global value that
319+
// represents an inheritance relationship that can't
320+
// be resolved to a witness table at IR-generation time.
321+
struct IRAbstractWitness : IRGlobalValue
322+
{
323+
RefPtr<SubtypeWitness> witness;
324+
DeclRef<Decl> subTypeDeclRef, supTypeDeclRef;
325+
};
312326

313327

314328
// Description of an instruction to be used for global value numbering
@@ -402,6 +416,15 @@ struct IRBuilder
402416
DeclRef<Decl> witnessTableDeclRef,
403417
DeclRef<Decl> interfaceMethodDeclRef);
404418

419+
IRValue* emitLookupInterfaceMethodInst(
420+
IRType* type,
421+
IRValue* witnessTableVal,
422+
DeclRef<Decl> interfaceMethodDeclRef);
423+
424+
IRValue* emitFindWitnessTable(
425+
DeclRef<Decl> baseTypeDeclRef,
426+
IRType* interfaceType);
427+
405428
IRInst* emitCallInst(
406429
IRType* type,
407430
IRValue* func,
@@ -424,7 +447,6 @@ struct IRBuilder
424447
IRFunc* createFunc();
425448
IRGlobalVar* createGlobalVar(
426449
IRType* valueType);
427-
IRWitnessTable* createWitnessTable(Dictionary<DeclRef<Decl>, Decl*> & witnesses);
428450
IRWitnessTable* createWitnessTable();
429451
IRWitnessTableEntry* createWitnessTableEntry(
430452
IRWitnessTable* witnessTable,

source/slang/ir.cpp

+71-7
Original file line numberDiff line numberDiff line change
@@ -678,11 +678,38 @@ namespace Slang
678678
DeclRef<Decl> interfaceMethodDeclRef)
679679
{
680680
auto witnessTableVal = getDeclRefVal(witnessTableDeclRef);
681-
auto interfaceMethodVal = getDeclRefVal(interfaceMethodDeclRef);
681+
DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef;
682+
removeSubstDeclRef.substitutions = nullptr;
683+
auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef);
682684
return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal);
683685
}
684686

687+
IRValue* IRBuilder::emitLookupInterfaceMethodInst(
688+
IRType* type,
689+
IRValue* witnessTableVal,
690+
DeclRef<Decl> interfaceMethodDeclRef)
691+
{
692+
DeclRef<Decl> removeSubstDeclRef = interfaceMethodDeclRef;
693+
removeSubstDeclRef.substitutions = nullptr;
694+
auto interfaceMethodVal = getDeclRefVal(removeSubstDeclRef);
695+
return emitLookupInterfaceMethodInst(type, witnessTableVal, interfaceMethodVal);
696+
}
685697

698+
IRValue* IRBuilder::emitFindWitnessTable(
699+
DeclRef<Decl> baseTypeDeclRef,
700+
IRType* interfaceType)
701+
{
702+
auto interfaceTypeDeclRef = interfaceType->AsDeclRefType();
703+
SLANG_ASSERT(interfaceTypeDeclRef);
704+
auto inst = createInst<IRLookupWitnessTable>(
705+
this,
706+
kIROp_lookup_witness_table,
707+
interfaceType,
708+
getDeclRefVal(baseTypeDeclRef),
709+
getDeclRefVal(interfaceTypeDeclRef->declRef));
710+
addInst(inst);
711+
return inst;
712+
}
686713

687714
IRInst* IRBuilder::emitCallInst(
688715
IRType* type,
@@ -3200,7 +3227,6 @@ namespace Slang
32003227
Dictionary<String, VarLayout*> globalVarLayouts;
32013228

32023229
RefPtr<GlobalGenericParamSubstitution> subst;
3203-
32043230
// Override the "maybe clone" logic so that we always clone
32053231
virtual IRValue* maybeCloneValue(IRValue* originalVal) override;
32063232

@@ -3228,6 +3254,7 @@ namespace Slang
32283254
return val->Substitute(subst);
32293255
}
32303256

3257+
32313258
IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue)
32323259
{
32333260
switch (originalValue->op)
@@ -3261,14 +3288,17 @@ namespace Slang
32613288
case kIROp_decl_ref:
32623289
{
32633290
IRDeclRef* od = (IRDeclRef*)originalValue;
3291+
auto newDeclRef = od->declRef;
32643292

32653293
// if the declRef is one of the __generic_param decl being substituted by subst
32663294
// return the substituted decl
32673295
if (subst)
32683296
{
3269-
if (od->declRef.getDecl() == subst->paramDecl)
3297+
int diff = 0;
3298+
newDeclRef = od->declRef.SubstituteImpl(subst, &diff);
3299+
if (newDeclRef.getDecl() == subst->paramDecl)
32703300
return builder->getTypeVal(subst->actualType.As<Type>());
3271-
else if (auto genConstraint = od->declRef.As<GenericTypeConstraintDecl>())
3301+
else if (auto genConstraint = newDeclRef.As<GenericTypeConstraintDecl>())
32723302
{
32733303
// a decl-ref to GenericTypeConstraintDecl as a result of
32743304
// referencing a generic parameter type should be replaced with
@@ -3288,7 +3318,7 @@ namespace Slang
32883318
}
32893319
}
32903320
}
3291-
auto declRef = maybeCloneDeclRef(od->declRef);
3321+
auto declRef = maybeCloneDeclRef(newDeclRef);
32923322
return builder->getDeclRefVal(declRef);
32933323
}
32943324
break;
@@ -3641,6 +3671,14 @@ namespace Slang
36413671
// and their instructions.
36423672
cloneFunctionCommon(context, clonedFunc, originalFunc);
36433673

3674+
//// for now, clone all unreferenced witness tables
3675+
//for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
3676+
// gv; gv = gv->getNextValue())
3677+
//{
3678+
// if (gv->op == kIROp_witness_table)
3679+
// cloneGlobalValue(context, (IRWitnessTable*)gv);
3680+
//}
3681+
36443682
// We need to attach the layout information for
36453683
// the entry point to this declaration, so that
36463684
// we can use it to inform downstream code emit.
@@ -4048,7 +4086,7 @@ namespace Slang
40484086
globalVar = globalVar->getNextValue();
40494087
}
40504088
SLANG_ASSERT(table);
4051-
table = cloneWitnessTableWithoutRegistering(context, (IRWitnessTable*)(table));
4089+
table = cloneGlobalValue(context, (IRWitnessTable*)(table));
40524090
IRProxyVal * tableVal = new IRProxyVal();
40534091
tableVal->inst.init(nullptr, table);
40544092
paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal));
@@ -4661,6 +4699,16 @@ namespace Slang
46614699
sharedContext->workList.Add(func);
46624700
}
46634701

4702+
// Build dictionary for witness tables
4703+
Dictionary<String, IRWitnessTable*> witnessTables;
4704+
for (auto gv = module->getFirstGlobalValue();
4705+
gv;
4706+
gv = gv->getNextValue())
4707+
{
4708+
if (gv->op == kIROp_witness_table)
4709+
witnessTables.AddIfNotExists(gv->mangledName, (IRWitnessTable*)gv);
4710+
}
4711+
46644712
// Now that we have our work list, we are going to
46654713
// process it until it goes empty. Along the way
46664714
// we may specialize a function and thus create
@@ -4738,12 +4786,28 @@ namespace Slang
47384786
// specialize a witness table
47394787
auto originalTable = (IRWitnessTable*)genericVal;
47404788
auto specWitnessTable = specializeWitnessTable(sharedContext, originalTable, specDeclRef);
4789+
witnessTables.AddIfNotExists(specWitnessTable->mangledName, specWitnessTable);
47414790
specInst->replaceUsesWith(specWitnessTable);
47424791
specInst->removeAndDeallocate();
47434792
}
47444793
}
47454794
break;
4746-
4795+
case kIROp_lookup_witness_table:
4796+
{
4797+
// try find concrete witness table from global scope
4798+
IRLookupWitnessTable* lookupInst = (IRLookupWitnessTable*)ii;
4799+
IRWitnessTable* witnessTable = nullptr;
4800+
auto srcDeclRef = ((IRDeclRef*)lookupInst->sourceType.usedValue)->declRef;
4801+
auto interfaceDeclRef = ((IRDeclRef*)lookupInst->interfaceType.usedValue)->declRef;
4802+
auto mangledName = getMangledNameForConformanceWitness(srcDeclRef, interfaceDeclRef);
4803+
witnessTables.TryGetValue(mangledName, witnessTable);
4804+
if (witnessTable)
4805+
{
4806+
lookupInst->replaceUsesWith(witnessTable);
4807+
lookupInst->removeAndDeallocate();
4808+
}
4809+
}
4810+
break;
47474811
case kIROp_lookup_interface_method:
47484812
{
47494813
// We have a `lookup_interface_method` instruction,

source/slang/lookup.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ void lookUpMemberImpl(
412412
auto declRef = declRefType->declRef;
413413
if (declRef.As<AssocTypeDecl>() || declRef.As<GlobalGenericParamDecl>())
414414
{
415-
for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(declRef.As<ContainerDecl>()))
415+
for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(declRef.As<ContainerDecl>()))
416416
{
417417
// The super-type in the constraint (e.g., `Foo` in `T : Foo`)
418418
// will tell us a type we should use for lookup.

source/slang/lower-to-ir.cpp

+57-7
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,52 @@ LoweredValInfo emitPostOp(
470470
return LoweredValInfo::ptr(argPtr);
471471
}
472472

473+
IRValue* findWitnessTable(
474+
IRGenContext* context,
475+
DeclRef<Decl> declRef);
476+
477+
LoweredValInfo emitWitnessTableRef(
478+
IRGenContext* context,
479+
Expr* expr)
480+
{
481+
if (auto mbrExpr = dynamic_cast<MemberExpr*>(expr))
482+
{
483+
if (auto inheritanceDeclRef = mbrExpr->declRef.As<InheritanceDecl>())
484+
{
485+
if (inheritanceDeclRef.getDecl()->ParentDecl->As<InterfaceDecl>() || inheritanceDeclRef.getDecl()->ParentDecl->As<AssocTypeDecl>())
486+
{
487+
RefPtr<Type> exprType = nullptr;
488+
if (auto tt = mbrExpr->BaseExpression->type->As<TypeType>())
489+
exprType = tt->type;
490+
else
491+
exprType = mbrExpr->BaseExpression->type;
492+
auto declRefType = exprType->GetCanonicalType()->AsDeclRefType();
493+
SLANG_ASSERT(declRefType);
494+
IRValue* witnessTableVal = nullptr;
495+
DeclRef<Decl> srcDeclRef = declRefType->declRef;
496+
if (!declRefType->declRef.As<AssocTypeDecl>())
497+
{
498+
// if we are referring to an actual type, don't include substitution
499+
// and generate specialize instruction
500+
srcDeclRef.substitutions = nullptr;
501+
}
502+
witnessTableVal = context->irBuilder->emitFindWitnessTable(srcDeclRef, inheritanceDeclRef.getDecl()->base.type);
503+
return maybeEmitSpecializeInst(context, LoweredValInfo::simple(witnessTableVal), declRefType->declRef);
504+
}
505+
else if (inheritanceDeclRef.getDecl()->ParentDecl->As<AggTypeDeclBase>())
506+
{
507+
return LoweredValInfo::simple(findWitnessTable(context, inheritanceDeclRef));
508+
}
509+
510+
}
511+
else if (auto genConstraintDeclRef = mbrExpr->declRef.As<GenericTypeConstraintDecl>())
512+
{
513+
return LoweredValInfo::simple(context->irBuilder->getDeclRefVal(genConstraintDeclRef));
514+
}
515+
}
516+
SLANG_UNEXPECTED("unknown witness table expression");
517+
}
518+
473519
// Emit a reference to a function, where we have concluded
474520
// that the original AST referenced `funcDeclRef`. The
475521
// optional expression `funcExpr` can provide additional
@@ -494,7 +540,7 @@ LoweredValInfo emitFuncRef(
494540
if(auto baseMemberExpr = baseExpr.As<MemberExpr>())
495541
{
496542
auto baseMemberDeclRef = baseMemberExpr->declRef;
497-
if(auto baseConstraintDeclRef = baseMemberDeclRef.As<GenericTypeConstraintDecl>())
543+
if(auto baseConstraintDeclRef = baseMemberDeclRef.As<TypeConstraintDecl>())
498544
{
499545
// We are calling a method "through" a generic type
500546
// parameter that was constrained to some type.
@@ -505,10 +551,10 @@ LoweredValInfo emitFuncRef(
505551
// find the corresponding member on our chosen type.
506552

507553
RefPtr<Type> type = funcExpr->type;
508-
554+
auto loweredBaseWitnessTable = emitWitnessTableRef(context, baseMemberExpr);
509555
auto loweredVal = LoweredValInfo::simple(context->irBuilder->emitLookupInterfaceMethodInst(
510556
type,
511-
baseMemberDeclRef,
557+
loweredBaseWitnessTable.val,
512558
funcDeclRef));
513559
return maybeEmitSpecializeInst(context, loweredVal, funcDeclRef);
514560
}
@@ -1184,7 +1230,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
11841230
boundMemberInfo->declRef = callableDeclRef;
11851231
return LoweredValInfo::boundMember(boundMemberInfo);
11861232
}
1187-
else if(auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>())
1233+
else if(auto constraintDeclRef = declRef.As<TypeConstraintDecl>())
11881234
{
11891235
// The code is making use of a "witness" that a value of
11901236
// some generic type conforms to an interface.
@@ -2770,10 +2816,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
27702816
for (auto entry : inheritanceDecl->requirementWitnesses)
27712817
{
27722818
auto requiredMemberDeclRef = entry.Key;
2773-
auto satisfyingMemberDecl = entry.Value;
2774-
2819+
auto satisfyingMemberDeclRef = entry.Value;
2820+
27752821
auto irRequirement = context->irBuilder->getDeclRefVal(requiredMemberDeclRef);
2776-
auto irSatisfyingVal = getSimpleVal(context, ensureDecl(context, satisfyingMemberDecl));
2822+
IRValue* irSatisfyingVal = nullptr;
2823+
if (satisfyingMemberDeclRef.As<GenericTypeConstraintDecl>())
2824+
irSatisfyingVal = context->irBuilder->getDeclRefVal(satisfyingMemberDeclRef);
2825+
else
2826+
irSatisfyingVal = getSimpleVal(context, ensureDecl(context, satisfyingMemberDeclRef));
27772827

27782828
context->irBuilder->createWitnessTableEntry(
27792829
witnessTable,

source/slang/parser.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2279,7 +2279,7 @@ namespace Slang
22792279
auto nameToken = parser->ReadToken(TokenType::Identifier);
22802280
assocTypeDecl->nameAndLoc = NameLoc(nameToken);
22812281
assocTypeDecl->loc = nameToken.loc;
2282-
parseOptionalGenericConstraints(parser, assocTypeDecl);
2282+
parseOptionalInheritanceClause(parser, assocTypeDecl);
22832283
parser->ReadToken(TokenType::Semicolon);
22842284
return assocTypeDecl;
22852285
}

0 commit comments

Comments
 (0)