Skip to content

Commit 90444f8

Browse files
authored
Generate IRType for interfaces, and reference them as operand[0] in IRWitnessTable values (shader-slang#1387)
* Generate IRType for interfaces, and use them as the type of IRWitnessTable values. This results the following IR for the included test case: ``` [export("_S3tu010IInterface7Computep1pii")] let %1 : _ = key [export("_ST3tu010IInterface")] [nameHint("IInterface")] interface %IInterface : _(%1); [export("_S3tu04Impl7Computep1pii")] [nameHint("Impl.Compute")] func %Implx5FCompute : Func(Int, Int) { block %2( [nameHint("inVal")] param %inVal : Int): let %3 : Int = mul(%inVal, %inVal) return_val(%3) } [export("_SW3tu04Impl3tu010IInterface")] witness_table %4 : %IInterface { witness_table_entry(%1,%Implx5FCompute) } ``` * Fixes per code review comments. Moved interface type reference in IRWitnessTable from their type to operand[0]. * Fix typo in comment.
1 parent 36a06f1 commit 90444f8

13 files changed

+144
-51
lines changed

source/slang/slang-ast-support-types.h

+3
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,9 @@ namespace Slang
12851285
struct WitnessTable : RefObject
12861286
{
12871287
RequirementDictionary requirementDictionary;
1288+
1289+
// The type that the witness table witnesses conformance to (e.g. an Interface)
1290+
Type* baseType;
12881291
};
12891292

12901293
typedef Dictionary<unsigned int, NodeBase*> AttributeArgumentValueDict;

source/slang/slang-check-decl.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,7 @@ namespace Slang
15811581
if(!witnessTable)
15821582
{
15831583
witnessTable = new WitnessTable();
1584+
witnessTable->baseType = DeclRefType::create(m_astBuilder, interfaceDeclRef);
15841585
}
15851586
context->mapInterfaceToWitnessTable.Add(interfaceDeclRef, witnessTable);
15861587

@@ -2137,6 +2138,7 @@ namespace Slang
21372138
// let them define a tag value with the name `__Tag`).
21382139
//
21392140
RefPtr<WitnessTable> witnessTable = new WitnessTable();
2141+
witnessTable->baseType = enumConformanceDecl->base.type;
21402142
enumConformanceDecl->witnessTable = witnessTable;
21412143

21422144
Name* tagAssociatedTypeName = getSession()->getNameObj("__Tag");

source/slang/slang-emit-c-like.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
217217
outNumThreads[i] = decor ? Int(getIntVal(decor->getOperand(i))) : 1;
218218
}
219219
return decor;
220-
}
221-
220+
}
221+
222222
void CLikeSourceEmitter::_emitArrayType(IRArrayType* arrayType, EDeclarator* declarator)
223223
{
224224
EDeclarator arrayDeclarator;
@@ -265,6 +265,12 @@ void CLikeSourceEmitter::_emitType(IRType* type, EDeclarator* declarator)
265265

266266
}
267267

268+
void CLikeSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
269+
{
270+
SLANG_UNUSED(witnessTable);
271+
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "Unimplemented emit: IROpWitnessTable.");
272+
}
273+
268274
void CLikeSourceEmitter::emitTypeImpl(IRType* type, const StringSliceLoc* nameAndLoc)
269275
{
270276
if (nameAndLoc)
@@ -3516,6 +3522,10 @@ void CLikeSourceEmitter::emitGlobalInst(IRInst* inst)
35163522
emitStruct(cast<IRStructType>(inst));
35173523
break;
35183524

3525+
case kIROp_WitnessTable:
3526+
emitWitnessTable(cast<IRWitnessTable>(inst));
3527+
break;
3528+
35193529
default:
35203530
// We have an "ordinary" instruction at the global
35213531
// scope, and we should therefore emit it using the

source/slang/slang-emit-c-like.h

+2
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ class CLikeSourceEmitter: public RefObject
340340
// Again necessary for & prefix intrinsics. May be removable in the future
341341
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) = 0;
342342

343+
virtual void emitWitnessTable(IRWitnessTable* witnessTable);
344+
343345
virtual void handleCallExprDecorationsImpl(IRInst* funcValue) { SLANG_UNUSED(funcValue); }
344346

345347
virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) { SLANG_UNUSED(varDecl); SLANG_UNUSED(varType); return false; }

source/slang/slang-emit-cpp.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -1559,6 +1559,11 @@ void CPPSourceEmitter::emitParamTypeImpl(IRType* type, String const& name)
15591559
emitType(type, name);
15601560
}
15611561

1562+
void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
1563+
{
1564+
SLANG_UNUSED(witnessTable);
1565+
}
1566+
15621567
bool CPPSourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType)
15631568
{
15641569
SLANG_UNUSED(varDecl);

source/slang/slang-emit-cpp.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class CPPSourceEmitter: public CLikeSourceEmitter
7575
virtual void emitSimpleFuncImpl(IRFunc* func) SLANG_OVERRIDE;
7676
virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec) SLANG_OVERRIDE;
7777
virtual void emitParamTypeImpl(IRType* type, String const& name) SLANG_OVERRIDE;
78-
78+
virtual void emitWitnessTable(IRWitnessTable* witnessTable) SLANG_OVERRIDE;
7979
virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE;
8080
virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
8181

source/slang/slang-emit.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,9 @@ SlangResult emitEntryPointSourceFromIR(
698698
//
699699
// TODO: do we want to emit directly from IR, or translate the
700700
// IR back into AST for emission?
701+
#if 0
702+
dumpIR(compileRequest, irModule, "PRE-EMIT");
703+
#endif
701704
sourceEmitter->emitModule(irModule);
702705
}
703706

source/slang/slang-ir-insts.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -1790,7 +1790,10 @@ struct IRBuilder
17901790
IRType* valueType);
17911791
IRGlobalParam* createGlobalParam(
17921792
IRType* valueType);
1793-
IRWitnessTable* createWitnessTable();
1793+
1794+
/// Creates an IRWitnessTable value.
1795+
/// @param baseType: The comformant-to type of this witness.
1796+
IRWitnessTable* createWitnessTable(IRType* baseType);
17941797
IRWitnessTableEntry* createWitnessTableEntry(
17951798
IRWitnessTable* witnessTable,
17961799
IRInst* requirementKey,
@@ -1800,7 +1803,7 @@ struct IRBuilder
18001803
IRStructType* createStructType();
18011804

18021805
// Create an empty `interface` type.
1803-
IRInterfaceType* createInterfaceType();
1806+
IRInterfaceType* createInterfaceType(UInt operandCount, IRInst* const* operands);
18041807

18051808
// Create a global "key" to use for indexing into a `struct` type.
18061809
IRStructKey* createStructKey();

source/slang/slang-ir-link.cpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,12 @@ IRWitnessTable* cloneWitnessTableImpl(
567567
IRWitnessTable* dstTable = nullptr,
568568
bool registerValue = true)
569569
{
570-
auto clonedTable = dstTable ? dstTable : builder->createWitnessTable();
570+
IRWitnessTable* clonedTable = dstTable;
571+
if (!clonedTable)
572+
{
573+
auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getOperand(0)));
574+
clonedTable = builder->createWitnessTable(clonedBaseType);
575+
}
571576
cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);
572577
return clonedTable;
573578
}
@@ -599,7 +604,13 @@ IRInterfaceType* cloneInterfaceTypeImpl(
599604
IRInterfaceType* originalInterface,
600605
IROriginalValuesForClone const& originalValues)
601606
{
602-
auto clonedInterface = builder->createInterfaceType();
607+
auto clonedInterface = builder->createInterfaceType(originalInterface->getOperandCount(), nullptr);
608+
for (UInt i = 0; i < originalInterface->getOperandCount(); i++)
609+
{
610+
auto clonedKey = findClonedValue(context, originalInterface->getOperand(i));
611+
SLANG_ASSERT(clonedKey);
612+
clonedInterface->setOperand(i, clonedKey);
613+
}
603614
cloneSimpleGlobalValueImpl(context, originalInterface, originalValues, clonedInterface);
604615
return clonedInterface;
605616
}

source/slang/slang-ir.cpp

+46-36
Original file line numberDiff line numberDiff line change
@@ -2770,12 +2770,13 @@ namespace Slang
27702770
return inst;
27712771
}
27722772

2773-
IRWitnessTable* IRBuilder::createWitnessTable()
2773+
IRWitnessTable* IRBuilder::createWitnessTable(IRType* baseType)
27742774
{
27752775
IRWitnessTable* witnessTable = createInst<IRWitnessTable>(
27762776
this,
27772777
kIROp_WitnessTable,
2778-
nullptr);
2778+
nullptr,
2779+
baseType);
27792780
addGlobalValue(this, witnessTable);
27802781
return witnessTable;
27812782
}
@@ -2810,12 +2811,14 @@ namespace Slang
28102811
return structType;
28112812
}
28122813

2813-
IRInterfaceType* IRBuilder::createInterfaceType()
2814+
IRInterfaceType* IRBuilder::createInterfaceType(UInt operandCount, IRInst* const* operands)
28142815
{
28152816
IRInterfaceType* interfaceType = createInst<IRInterfaceType>(
28162817
this,
28172818
kIROp_InterfaceType,
2818-
nullptr);
2819+
nullptr,
2820+
operandCount,
2821+
operands);
28192822
addGlobalValue(this, interfaceType);
28202823
return interfaceType;
28212824
}
@@ -4209,6 +4212,42 @@ namespace Slang
42094212
dump(context, "}");
42104213
}
42114214

4215+
static void dumpInstOperandList(
4216+
IRDumpContext* context,
4217+
IRInst* inst)
4218+
{
4219+
UInt argCount = inst->getOperandCount();
4220+
4221+
if (argCount == 0)
4222+
return;
4223+
4224+
UInt ii = 0;
4225+
4226+
// Special case: make printing of `call` a bit
4227+
// nicer to look at
4228+
if (inst->op == kIROp_Call && argCount > 0)
4229+
{
4230+
dump(context, " ");
4231+
auto argVal = inst->getOperand(ii++);
4232+
dumpOperand(context, argVal);
4233+
}
4234+
4235+
bool first = true;
4236+
dump(context, "(");
4237+
for (; ii < argCount; ++ii)
4238+
{
4239+
if (!first)
4240+
dump(context, ", ");
4241+
4242+
auto argVal = inst->getOperand(ii);
4243+
4244+
dumpOperand(context, argVal);
4245+
4246+
first = false;
4247+
}
4248+
4249+
dump(context, ")");
4250+
}
42124251

42134252
void dumpIRWitnessTableEntry(
42144253
IRDumpContext* context,
@@ -4234,6 +4273,8 @@ namespace Slang
42344273

42354274
dumpInstTypeClause(context, inst->getFullType());
42364275

4276+
dumpInstOperandList(context, inst);
4277+
42374278
if (!inst->getFirstChild())
42384279
{
42394280
// Empty.
@@ -4321,38 +4362,7 @@ namespace Slang
43214362

43224363
dump(context, opInfo.name);
43234364

4324-
UInt argCount = inst->getOperandCount();
4325-
4326-
if(argCount == 0)
4327-
return;
4328-
4329-
UInt ii = 0;
4330-
4331-
// Special case: make printing of `call` a bit
4332-
// nicer to look at
4333-
if (inst->op == kIROp_Call && argCount > 0)
4334-
{
4335-
dump(context, " ");
4336-
auto argVal = inst->getOperand(ii++);
4337-
dumpOperand(context, argVal);
4338-
}
4339-
4340-
bool first = true;
4341-
dump(context, "(");
4342-
for (; ii < argCount; ++ii)
4343-
{
4344-
if (!first)
4345-
dump(context, ", ");
4346-
4347-
auto argVal = inst->getOperand(ii);
4348-
4349-
dumpOperand(context, argVal);
4350-
4351-
first = false;
4352-
}
4353-
4354-
dump(context, ")");
4355-
4365+
dumpInstOperandList(context, inst);
43564366
}
43574367

43584368
static void dumpInstBody(

source/slang/slang-lower-to-ir.cpp

+12-8
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
11051105
UNREACHABLE_RETURN(LoweredValInfo());
11061106
}
11071107

1108-
auto irWitnessTable = getBuilder()->createWitnessTable();
1108+
auto irWitnessTableBaseType = lowerType(context, supDeclRefType);
1109+
auto irWitnessTable = getBuilder()->createWitnessTable(irWitnessTableBaseType);
11091110

11101111
// Now we will iterate over the requirements (members) of the
11111112
// interface and try to synthesize an appropriate value for each.
@@ -4524,7 +4525,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
45244525
if(!mapASTToIRWitnessTable.TryGetValue(astReqWitnessTable, irSatisfyingWitnessTable))
45254526
{
45264527
// Need to construct a sub-witness-table
4527-
irSatisfyingWitnessTable = subBuilder->createWitnessTable();
4528+
auto irWitnessTableBaseType = lowerType(subContext, astReqWitnessTable->baseType);
4529+
irSatisfyingWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType);
45284530

45294531
// Recursively lower the sub-table.
45304532
lowerWitnessTable(
@@ -4637,10 +4639,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
46374639
// and we need those parameters to lower as references to
46384640
// the parameters of our IR-level generic.
46394641
//
4640-
lowerType(subContext, superType);
4642+
auto irWitnessTableBaseType = lowerType(subContext, superType);
46414643

46424644
// Create the IR-level witness table
4643-
auto irWitnessTable = subBuilder->createWitnessTable();
4645+
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType);
46444646
addLinkageDecoration(context, irWitnessTable, inheritanceDecl, mangledName.getUnownedSlice());
46454647

46464648
// Register the value now, rather than later, to avoid any possible infinite recursion.
@@ -5243,9 +5245,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
52435245
// a witness table for the interface type's conformance
52445246
// to its own interface.
52455247
//
5248+
List<IRStructKey*> requirementKeys;
52465249
for (auto requirementDecl : decl->members)
52475250
{
5248-
getInterfaceRequirementKey(requirementDecl);
5251+
requirementKeys.add(getInterfaceRequirementKey(requirementDecl));
52495252

52505253
// As a special case, any type constraints placed
52515254
// on an associated type will *also* need to be turned
@@ -5254,7 +5257,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
52545257
{
52555258
for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>())
52565259
{
5257-
getInterfaceRequirementKey(constraintDecl);
5260+
requirementKeys.add(getInterfaceRequirementKey(constraintDecl));
52585261
}
52595262
}
52605263
}
@@ -5267,11 +5270,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
52675270
// Emit any generics that should wrap the actual type.
52685271
emitOuterGenerics(subContext, decl, decl);
52695272

5270-
IRInterfaceType* irInterface = subBuilder->createInterfaceType();
5273+
IRInterfaceType* irInterface = subBuilder->createInterfaceType(
5274+
requirementKeys.getCount(),
5275+
reinterpret_cast<IRInst**>(requirementKeys.getBuffer()));
52715276
addNameHint(context, irInterface, decl);
52725277
addLinkageDecoration(context, irInterface, decl);
52735278
subBuilder->setInsertInto(irInterface);
5274-
52755279
// TODO: are there any interface members that should be
52765280
// nested inside the interface type itself?
52775281

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//TEST_IGNORE_FILE
2+
//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code -xslang -dump-ir
3+
4+
// Test basic dynamic dispatch code gen
5+
6+
interface IInterface
7+
{
8+
static int Compute(int inVal);
9+
};
10+
11+
int GenericCompute<T:IInterface>(int inVal)
12+
{
13+
return T.Compute(inVal);
14+
}
15+
16+
struct Impl : IInterface
17+
{
18+
static int Compute(int inVal) { return inVal * inVal; }
19+
};
20+
21+
int test(int inVal)
22+
{
23+
return GenericCompute<Impl>(inVal);
24+
}
25+
26+
//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
27+
RWStructuredBuffer<int> outputBuffer : register(u0);
28+
29+
[numthreads(4, 1, 1)]
30+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
31+
{
32+
uint tid = dispatchThreadID.x;
33+
int inVal = outputBuffer[tid];
34+
int outVal = test(inVal);
35+
outputBuffer[tid] = outVal;
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
0
2+
1
3+
4
4+
9

0 commit comments

Comments
 (0)