Skip to content

Commit 0ca75fe

Browse files
committed
Dynamic dispatch for generic interface requirements.
-Lower interfaces into actual `IRInterfaceType` insts. -Lower `DeclRef<AssocTypeDecl>` into `IRAssociatedType` -Generate proper IRType for generic functions. -Add a test case exercising dynamic dispatching a generic static function through an associated type. -Bug fixes for the test case.
1 parent 3fe4f53 commit 0ca75fe

12 files changed

+593
-170
lines changed

source/slang/slang-emit-c-like.cpp

+17-4
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,9 @@ List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWi
236236
// Get a sorted list of entries using RequirementKeys defined in `interfaceType`.
237237
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
238238
{
239-
auto reqKey = cast<IRStructKey>(interfaceType->getOperand(i));
239+
auto reqEntry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
240240
IRWitnessTableEntry* entry = nullptr;
241-
if (witnessTableEntryDictionary.TryGetValue(reqKey, entry))
241+
if (witnessTableEntryDictionary.TryGetValue(reqEntry->getRequirementKey(), entry))
242242
{
243243
sortedWitnessTableEntries.add(entry);
244244
}
@@ -1962,6 +1962,10 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
19621962
are hashed with 'getStringHash' */
19631963
break;
19641964

1965+
case kIROp_undefined:
1966+
m_writer->emit(getName(inst));
1967+
break;
1968+
19651969
case kIROp_IntLit:
19661970
case kIROp_FloatLit:
19671971
case kIROp_BoolLit:
@@ -3554,6 +3558,11 @@ void CLikeSourceEmitter::emitGlobalInst(IRInst* inst)
35543558
are hashed with 'getStringHash' */
35553559
break;
35563560

3561+
case kIROp_InterfaceRequirementEntry:
3562+
// Don't emit anything for interface requirement at global level.
3563+
// They are handled in `emitInterface`.
3564+
break;
3565+
35573566
case kIROp_Func:
35583567
emitFunc((IRFunc*) inst);
35593568
break;
@@ -3610,6 +3619,10 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I
36103619
ensureInstOperand(ctx, inst->getFullType());
36113620

36123621
UInt operandCount = inst->operandCount;
3622+
auto requiredLevel = EmitAction::Definition;
3623+
if (inst->op == kIROp_InterfaceType)
3624+
requiredLevel = EmitAction::ForwardDeclaration;
3625+
36133626
for(UInt ii = 0; ii < operandCount; ++ii)
36143627
{
36153628
// TODO: there are some special cases we can add here,
@@ -3620,8 +3633,8 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I
36203633
// only need the type they point to to be forward-declared.
36213634
// Similarly, a `call` instruction only needs the callee
36223635
// to be forward-declared, etc.
3623-
3624-
ensureInstOperand(ctx, inst->getOperand(ii));
3636+
3637+
ensureInstOperand(ctx, inst->getOperand(ii), requiredLevel);
36253638
}
36263639

36273640
for(auto child : inst->getDecorationsAndChildren())

source/slang/slang-emit-cpp.cpp

+35-29
Original file line numberDiff line numberDiff line change
@@ -390,12 +390,27 @@ static UnownedStringSlice _getResourceTypePrefix(IROp op)
390390
}
391391
}
392392

393+
static bool isVoidPtrType(IRType* type)
394+
{
395+
auto ptrType = as<IRPtrType>(type);
396+
if (!ptrType) return false;
397+
return ptrType->getValueType()->op == kIROp_VoidType;
398+
}
399+
393400
SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out)
394401
{
395402
switch (type->op)
396403
{
397404
case kIROp_PtrType:
398405
{
406+
if (isVoidPtrType(type))
407+
{
408+
// A `void*` type will always emit as `void*`.
409+
// `void*` types are generated as a result of generics lowering
410+
// for dynamic dispatch.
411+
out << "void*";
412+
return SLANG_OK;
413+
}
399414
auto ptrType = static_cast<IRPtrType*>(type);
400415
SLANG_RETURN_ON_FAIL(calcTypeName(ptrType->getValueType(), target, out));
401416
// TODO(JS): It seems although it says it is a pointer, it can actually be output as a reference
@@ -494,7 +509,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
494509
// struct of function pointers corresponding to the interface type.
495510
auto witnessTableType = static_cast<IRWitnessTableType*>(type);
496511
auto baseType = cast<IRType>(witnessTableType->getOperand(0));
497-
emitType(baseType);
512+
SLANG_RETURN_ON_FAIL(calcTypeName(baseType, target, out));
498513
out << "*";
499514
return SLANG_OK;
500515
}
@@ -1591,8 +1606,7 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
15911606
{
15921607
auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0));
15931608
auto witnessTableItems = witnessTable->getChildren();
1594-
List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable);
1595-
_maybeEmitWitnessTableTypeDefinition(interfaceType, sortedWitnessTableEntries);
1609+
_maybeEmitWitnessTableTypeDefinition(interfaceType);
15961610

15971611
// Define a global variable for the witness table.
15981612
m_writer->emit("extern ");
@@ -1747,51 +1761,52 @@ void CPPSourceEmitter::emitInterface(IRInterfaceType* interfaceType)
17471761
/// acoording to the order defined by `interfaceType`.
17481762
///
17491763
void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition(
1750-
IRInterfaceType* interfaceType,
1751-
const List<IRWitnessTableEntry*>& sortedWitnessTableEntries)
1764+
IRInterfaceType* interfaceType)
17521765
{
17531766
m_writer->emit("struct ");
17541767
emitSimpleType(interfaceType);
17551768
m_writer->emit("\n{\n");
17561769
m_writer->indent();
1757-
for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++)
1770+
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
17581771
{
1759-
auto entry = sortedWitnessTableEntries[i];
1760-
if (auto funcVal = as<IRFunc>(entry->satisfyingVal.get()))
1772+
auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
1773+
if (auto funcVal = as<IRFuncType>(entry->getRequirementVal()))
17611774
{
17621775
emitType(funcVal->getResultType());
17631776
m_writer->emit(" (KernelContext::*");
17641777
m_writer->emit(getName(entry->requirementKey.get()));
17651778
m_writer->emit(")");
17661779
m_writer->emit("(");
17671780
bool isFirstParam = true;
1768-
for (auto param : funcVal->getParams())
1781+
for (UInt p = 0; p < funcVal->getParamCount(); p++)
17691782
{
1783+
auto paramType = funcVal->getParamType(p);
1784+
// Ingore TypeType-typed parameters for now.
1785+
if (as<IRTypeType>(paramType))
1786+
continue;
1787+
17701788
if (!isFirstParam)
17711789
m_writer->emit(", ");
17721790
else
17731791
isFirstParam = false;
1774-
if (param->findDecoration<IRThisPointerDecoration>())
1792+
auto thisDecor = funcVal->findDecoration<IRThisPointerDecoration>();
1793+
if (thisDecor && cast<IRIntLit>(thisDecor->getOperand(0))->value.intVal == (IRIntegerValue)p)
17751794
{
1776-
m_writer->emit("void* ");
1777-
m_writer->emit(getName(param));
1795+
m_writer->emit("void* param");
1796+
m_writer->emit(p);
17781797
continue;
17791798
}
1780-
emitSimpleFuncParamImpl(param);
1799+
emitParamType(paramType, String("param") + String(p));
17811800
}
17821801
m_writer->emit(");\n");
17831802
}
1784-
else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal()))
1803+
else if (auto constraintInterfaceType = as<IRInterfaceType>(entry->getRequirementVal()))
17851804
{
1786-
emitType(as<IRType>(witnessTableVal->getOperand(0)));
1805+
emitType(constraintInterfaceType);
17871806
m_writer->emit("* ");
17881807
m_writer->emit(getName(entry->requirementKey.get()));
17891808
m_writer->emit(";\n");
17901809
}
1791-
else
1792-
{
1793-
// TODO: handle other witness table entry types.
1794-
}
17951810
}
17961811
m_writer->dedent();
17971812
m_writer->emit("};\n");
@@ -1990,23 +2005,14 @@ void CPPSourceEmitter::emitSimpleValueImpl(IRInst* inst)
19902005
}
19912006
}
19922007

1993-
static bool isVoidPtrType(IRType* type)
1994-
{
1995-
auto ptrType = as<IRPtrType>(type);
1996-
if (!ptrType) return false;
1997-
return ptrType->getValueType()->op == kIROp_VoidType;
1998-
}
1999-
20002008
void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
20012009
{
20022010
// Polymorphic types are already translated to void* type in
20032011
// lower-generics pass. However, the current emitting logic will
20042012
// emit "void&" instead of "void*" for pointer types.
20052013
// In the future, we will handle pointer types more properly,
20062014
// and this override logic will not be necessary.
2007-
// For now we special-case this scenario.
2008-
if (param->findDecoration<IRPolymorphicDecoration>() &&
2009-
isVoidPtrType(param->getDataType()))
2015+
if (isVoidPtrType(param->getDataType()))
20102016
{
20112017
m_writer->emit("void* ");
20122018
m_writer->emit(getName(param));

source/slang/slang-emit-cpp.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class CPPSourceEmitter: public CLikeSourceEmitter
8989
virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder);
9090

9191
// Emits a struct of function pointers defined in `interfaceType`.
92-
void _maybeEmitWitnessTableTypeDefinition(IRInterfaceType* interfaceType, const List<IRWitnessTableEntry*>& sortedWitnessTableEntries);
92+
void _maybeEmitWitnessTableTypeDefinition(IRInterfaceType* interfaceType);
9393
void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp);
9494

9595
void _emitForwardDeclarations(const List<EmitAction>& actions);

source/slang/slang-ir-inst-defs.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ INST(Nop, nop, 0, 0)
164164
// `field` instructions.
165165
//
166166
INST(StructType, struct, 0, PARENT)
167-
INST(InterfaceType, interface, 0, PARENT)
167+
INST(InterfaceType, interface, 0, 0)
168+
INST(AssociatedType, associated_type, 0, 0)
168169

169170
// A TypeType-typed IRValue represents a IRType.
170171
// It is used to represent a type parameter/argument in a generics.
@@ -223,6 +224,7 @@ INST(Call, call, 1, 0)
223224

224225

225226
INST(WitnessTableEntry, witness_table_entry, 2, 0)
227+
INST(InterfaceRequirementEntry, interface_req_entry, 2, 0)
226228

227229
INST(Param, param, 0, 0)
228230
INST(StructField, field, 2, 0)
@@ -507,14 +509,12 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
507509

508510
INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0)
509511

510-
/// A `[polymorphic]` decoration marks a function parameter that should translate to an abstract type
511-
/// e.g. (void*) that are casted to actual type before use. For example, a parameter of generic type
512-
/// is marked `[polymorphic]`, so that the code gen logic can emit it as a `void*` parameter,
513-
/// allowing the function to be used at sites that are agnostic of the actual object type.
514-
INST(PolymorphicDecoration, polymorphic, 0, 0)
515512

516513
/// A `[this_ptr]` decoration marks a function parameter that serves as `this` pointer.
517-
INST(ThisPointerDecoration, this_ptr, 0, 0)
514+
/// `[this_ptr]` decoration is also used to mark an `IRFunc` as a non-static function.
515+
/// The argument is an integer value that represents the index of the `this` parameter,
516+
/// which is always 0.
517+
INST(ThisPointerDecoration, this_ptr, 1, 0)
518518

519519

520520
/// A `[format(f)]` decoration specifies that the format of an image should be `f`

source/slang/slang-ir-insts.h

+15-10
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,6 @@ IR_SIMPLE_DECORATION(VulkanCallablePayloadDecoration)
165165
/// vulkan hit attributes, and should have a location assigned
166166
/// to it.
167167
IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration)
168-
169-
IR_SIMPLE_DECORATION(PolymorphicDecoration)
170168
IR_SIMPLE_DECORATION(ThisPointerDecoration)
171169

172170

@@ -410,9 +408,13 @@ struct IRLookupWitnessMethod : IRInst
410408
{
411409
IRUse witnessTable;
412410
IRUse requirementKey;
411+
IRUse interfaceType;
413412

414413
IRInst* getWitnessTable() { return witnessTable.get(); }
415414
IRInst* getRequirementKey() { return requirementKey.get(); }
415+
IRInst* getInterfaceType() { return interfaceType.get(); }
416+
417+
IR_LEAF_ISA(lookup_interface_method)
416418
};
417419

418420
struct IRLookupWitnessTable : IRInst
@@ -1675,7 +1677,8 @@ struct IRBuilder
16751677
IRInst* emitLookupInterfaceMethodInst(
16761678
IRType* type,
16771679
IRInst* witnessTableVal,
1678-
IRInst* interfaceMethodVal);
1680+
IRInst* interfaceMethodVal,
1681+
IRType* interfaceType);
16791682

16801683
IRInst* emitCallInst(
16811684
IRType* type,
@@ -1809,9 +1812,16 @@ struct IRBuilder
18091812
IRInst* requirementKey,
18101813
IRInst* satisfyingVal);
18111814

1815+
IRInterfaceRequirementEntry* createInterfaceRequirementEntry(
1816+
IRInst* requirementKey,
1817+
IRInst* requirementVal);
1818+
18121819
// Create an initially empty `struct` type.
18131820
IRStructType* createStructType();
18141821

1822+
// Create an IRType representing an `associatedtype` decl.
1823+
IRAssociatedType* createAssociatedType();
1824+
18151825
// Create an empty `interface` type.
18161826
IRInterfaceType* createInterfaceType(UInt operandCount, IRInst* const* operands);
18171827

@@ -2160,14 +2170,9 @@ struct IRBuilder
21602170
addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode)));
21612171
}
21622172

2163-
void addPolymorphicDecoration(IRInst* value)
2164-
{
2165-
addDecoration(value, kIROp_PolymorphicDecoration);
2166-
}
2167-
2168-
void addThisPointerDecoration(IRInst* value)
2173+
void addThisPointerDecoration(IRInst* value, int paramIndex)
21692174
{
2170-
addDecoration(value, kIROp_ThisPointerDecoration);
2175+
addDecoration(value, kIROp_ThisPointerDecoration, getIntValue(getIntType(), paramIndex));
21712176
}
21722177

21732178
void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0)

source/slang/slang-ir-link.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
228228
case kIROp_StructKey:
229229
case kIROp_GlobalGenericParam:
230230
case kIROp_WitnessTable:
231+
case kIROp_InterfaceType:
231232
case kIROp_TaggedUnionType:
232233
return cloneGlobalValue(this, originalValue);
233234

@@ -607,8 +608,7 @@ IRInterfaceType* cloneInterfaceTypeImpl(
607608
auto clonedInterface = builder->createInterfaceType(originalInterface->getOperandCount(), nullptr);
608609
for (UInt i = 0; i < originalInterface->getOperandCount(); i++)
609610
{
610-
auto clonedKey = findClonedValue(context, originalInterface->getOperand(i));
611-
SLANG_ASSERT(clonedKey);
611+
auto clonedKey = cloneValue(context, originalInterface->getOperand(i));
612612
clonedInterface->setOperand(i, clonedKey);
613613
}
614614
cloneSimpleGlobalValueImpl(context, originalInterface, originalValues, clonedInterface);
@@ -628,6 +628,7 @@ void cloneGlobalValueWithCodeCommon(
628628

629629
cloneDecorations(context, clonedValue, originalValue);
630630
cloneExtraDecorations(context, clonedValue, originalValues);
631+
clonedValue->setFullType((IRType*)cloneValue(context, originalValue->getFullType()));
631632

632633
// We will walk through the blocks of the function, and clone each of them.
633634
//

0 commit comments

Comments
 (0)