Skip to content

Commit 58a5db2

Browse files
committed
WIP: Switch WitnessTable to HOISTABLE
This commit deduplicates Witnessable by making it HOISTABLE. Three problems were observed when changed to HOISTABLE, and they are addressed in this commit. 1. A HOISTABLE IRInst is immutable once created. `SetOperand()` functions became unavailable. There were code that was modifying WitnessTable after creating it. This commit changed the behavior so that all operands are prepared before creating a new WitnessTable, and all operands are given to its constructor. 1. WitnessTable started to have duplicated decorations and children. Due to the nature of HOISTABLE, when create a new WitnessTable, a pre-existing WitnessTable is returned and reused. There isn't an easy way to tell if the WitnessTable is a brand new or reused at the moment. This commit assumes that it is a brand new only when it has no decorations and chilrend after its creation. 1. In `SimplifyIR()`, the order of children were slightly changed as a result of the optimization. The behavior was little different when WitnessTable became HOISTABLE. This resulted in an error where WitnessTable has a WitnessEntry pointing to an incorrect `IRSpecialize`. In order for it to function properly, `IRSpecialize` had to appear before `IRWitnessTable`.
1 parent 3c096a7 commit 58a5db2

6 files changed

+151
-65
lines changed

source/slang/slang-ir-autodiff.cpp

+20-7
Original file line numberDiff line numberDiff line change
@@ -3177,14 +3177,27 @@ struct AutoDiffPass : public InstPassBase
31773177
List<IRInst*> args;
31783178
for (auto param : genType->getParams())
31793179
args.add(param);
3180-
as<IRWitnessTable>(innerResult.diffWitness)
3181-
->setConcreteType((IRType*)builder.emitSpecializeInst(
3182-
builder.getTypeKind(),
3183-
originalType,
3184-
(UInt)args.getCount(),
3185-
args.getBuffer()));
3180+
3181+
auto concreteType = as<IRType>(builder.emitSpecializeInst(
3182+
builder.getTypeKind(),
3183+
originalType,
3184+
(UInt)args.getCount(),
3185+
args.getBuffer()));
3186+
3187+
auto witnessTableType = innerResult.diffWitness->getFullType();
3188+
auto newWitnessTable = builder.createWitnessTable(witnessTableType, concreteType);
3189+
3190+
// Copy all entries from the old witness table to the new one
3191+
for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
3192+
{
3193+
builder.createWitnessTableEntry(
3194+
newWitnessTable,
3195+
entry->getRequirementKey(),
3196+
entry->getSatisfyingVal());
3197+
}
3198+
31863199
result.diffWitness =
3187-
hoistValueFromGeneric(builder, innerResult.diffWitness, specInst, true);
3200+
hoistValueFromGeneric(builder, newWitnessTable, specInst, true);
31883201
}
31893202
return result;
31903203
}

source/slang/slang-ir-inst-defs.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ INST(GlobalConstant, globalConstant, 0, GLOBAL)
294294

295295
INST(StructKey, key, 0, GLOBAL)
296296
INST(GlobalGenericParam, global_generic_param, 0, GLOBAL)
297-
INST(WitnessTable, witness_table, 0, 0)
297+
INST(WitnessTable, witness_table, 0, HOISTABLE)
298298

299299
INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE)
300300

source/slang/slang-ir-insts.h

-2
Original file line numberDiff line numberDiff line change
@@ -2950,8 +2950,6 @@ struct IRWitnessTable : IRInst
29502950

29512951
IRType* getConcreteType() { return (IRType*)getOperand(0); }
29522952

2953-
void setConcreteType(IRType* t) { return setOperand(0, t); }
2954-
29552953
IR_LEAF_ISA(WitnessTable)
29562954
};
29572955

source/slang/slang-ir.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,21 @@ void addHoistableInst(IRBuilder* builder, IRInst* inst)
16941694
while (insertBeforeInst && insertBeforeInst->getOp() == kIROp_Param)
16951695
insertBeforeInst = insertBeforeInst->getNextInst();
16961696

1697+
if (inst->getOp() == kIROp_WitnessTable)
1698+
{
1699+
// WitnessTable may reference specialize inst-s from its WitnessEntry
1700+
// children. In this case, specialize insts must be cloned before the
1701+
// WitnessTable.
1702+
//
1703+
for (IRInst* iter = insertBeforeInst; iter; )
1704+
{
1705+
bool isSpecialize = (iter->getOp() == kIROp_Specialize);
1706+
iter = iter->getNextInst();
1707+
if (isSpecialize)
1708+
insertBeforeInst = iter;
1709+
}
1710+
}
1711+
16971712
// For instructions that will be placed at module scope,
16981713
// we don't care about relative ordering, but for everything
16991714
// else, we want to ensure that an instruction comes after

source/slang/slang-lower-to-ir.cpp

+82-55
Original file line numberDiff line numberDiff line change
@@ -8016,27 +8016,38 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
80168016
irSatisfyingWitnessTable = subBuilder->createWitnessTable(
80178017
irWitnessTableBaseType,
80188018
irWitnessTable->getConcreteType());
8019-
auto mangledName = getMangledNameForConformanceWitness(
8020-
subContext->astBuilder,
8021-
astReqWitnessTable->witnessedType,
8022-
astReqWitnessTable->baseType);
8023-
subBuilder->addExportDecoration(
8024-
irSatisfyingWitnessTable,
8025-
mangledName.getUnownedSlice());
8026-
if (isExportedType(astReqWitnessTable->witnessedType))
8019+
8020+
// TODO: When WitnessTable became HOISTABLE, we needed a way to avoid
8021+
// adding the same decorations and childrens redundantly.
8022+
// There isn't an easy way to tell if the returned WitnessTable is a
8023+
// brand new or a found one from the existing pool.
8024+
// The code below assumes that when there are any decoration or child,
8025+
// it is a pre-existed one.
8026+
//
8027+
if (irSatisfyingWitnessTable->getFirstDecorationOrChild() == nullptr)
80278028
{
8028-
subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable);
8029-
subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable);
8030-
}
8029+
auto mangledName = getMangledNameForConformanceWitness(
8030+
subContext->astBuilder,
8031+
astReqWitnessTable->witnessedType,
8032+
astReqWitnessTable->baseType);
8033+
subBuilder->addExportDecoration(
8034+
irSatisfyingWitnessTable,
8035+
mangledName.getUnownedSlice());
8036+
if (isExportedType(astReqWitnessTable->witnessedType))
8037+
{
8038+
subBuilder->addHLSLExportDecoration(irSatisfyingWitnessTable);
8039+
subBuilder->addKeepAliveDecoration(irSatisfyingWitnessTable);
8040+
}
80318041

8032-
// Recursively lower the sub-table.
8033-
lowerWitnessTable(
8034-
subContext,
8035-
astReqWitnessTable,
8036-
irSatisfyingWitnessTable,
8037-
mapASTToIRWitnessTable);
8042+
// Recursively lower the sub-table.
8043+
lowerWitnessTable(
8044+
subContext,
8045+
astReqWitnessTable,
8046+
irSatisfyingWitnessTable,
8047+
mapASTToIRWitnessTable);
80388048

8039-
irSatisfyingWitnessTable->moveToEnd();
8049+
irSatisfyingWitnessTable->moveToEnd();
8050+
}
80408051
}
80418052
irSatisfyingVal = irSatisfyingWitnessTable;
80428053
}
@@ -8145,58 +8156,74 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
81458156
//
81468157
auto irWitnessTableBaseType = lowerType(subContext, superType);
81478158

8148-
// Create the IR-level witness table
8149-
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr);
8150-
8151-
// Register the value now, rather than later, to avoid any possible infinite recursion.
8159+
// Register a dummy value to avoid infinite recursions.
8160+
// Without this, the call to lowerType() can get into an infinite recursion.
8161+
//
81528162
context->setGlobalValue(
81538163
inheritanceDecl,
8154-
LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
8164+
LoweredValInfo::simple(findOuterMostGeneric(subBuilder->getInsertLoc().getParent())));
81558165

81568166
auto irSubType = lowerType(subContext, subType);
8157-
irWitnessTable->setConcreteType(irSubType);
81588167

8159-
// TODO(JS):
8160-
// Should the mangled name take part in obfuscation if enabled?
8168+
// Create the IR-level witness table
8169+
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, irSubType);
8170+
8171+
// TODO: When WitnessTable became HOISTABLE, we needed a way to avoid
8172+
// adding the same decorations and childrens redundantly.
8173+
// There isn't an easy way to tell if the returned WitnessTable is a
8174+
// brand new or a found one from the existing pool.
8175+
// The code below assumes that when there are any decoration or child,
8176+
// it is a pre-existed one.
8177+
//
8178+
if (irWitnessTable->getFirstDecorationOrChild() == nullptr)
8179+
{
8180+
// Override with the correct witness-table
8181+
context->setGlobalValue(
8182+
inheritanceDecl,
8183+
LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
81618184

8162-
addLinkageDecoration(
8163-
context,
8164-
irWitnessTable,
8165-
inheritanceDecl,
8166-
mangledName.getUnownedSlice());
8185+
// TODO(JS):
8186+
// Should the mangled name take part in obfuscation if enabled?
81678187

8168-
// If the witness table is for a COM interface, always keep it alive.
8169-
if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
8170-
{
8171-
subBuilder->addHLSLExportDecoration(irWitnessTable);
8172-
}
8188+
addLinkageDecoration(
8189+
context,
8190+
irWitnessTable,
8191+
inheritanceDecl,
8192+
mangledName.getUnownedSlice());
81738193

8174-
for (auto mod : parentDecl->modifiers)
8175-
{
8176-
if (as<HLSLExportModifier>(mod))
8194+
// If the witness table is for a COM interface, always keep it alive.
8195+
if (irWitnessTableBaseType->findDecoration<IRComInterfaceDecoration>())
81778196
{
81788197
subBuilder->addHLSLExportDecoration(irWitnessTable);
8179-
subBuilder->addKeepAliveDecoration(irWitnessTable);
81808198
}
8181-
else if (as<AutoDiffBuiltinAttribute>(mod))
8199+
8200+
for (auto mod : parentDecl->modifiers)
81828201
{
8183-
subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
8202+
if (as<HLSLExportModifier>(mod))
8203+
{
8204+
subBuilder->addHLSLExportDecoration(irWitnessTable);
8205+
subBuilder->addKeepAliveDecoration(irWitnessTable);
8206+
}
8207+
else if (as<AutoDiffBuiltinAttribute>(mod))
8208+
{
8209+
subBuilder->addAutoDiffBuiltinDecoration(irWitnessTable);
8210+
}
81848211
}
8185-
}
81868212

8187-
// Make sure that all the entries in the witness table have been filled in,
8188-
// including any cases where there are sub-witness-tables for conformances
8189-
bool isExplicitExtern = false;
8190-
if (!isImportedDecl(context, parentDecl, isExplicitExtern))
8191-
{
8192-
Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
8193-
lowerWitnessTable(
8194-
subContext,
8195-
inheritanceDecl->witnessTable,
8196-
irWitnessTable,
8197-
mapASTToIRWitnessTable);
8213+
// Make sure that all the entries in the witness table have been filled in,
8214+
// including any cases where there are sub-witness-tables for conformances
8215+
bool isExplicitExtern = false;
8216+
if (!isImportedDecl(context, parentDecl, isExplicitExtern))
8217+
{
8218+
Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable;
8219+
lowerWitnessTable(
8220+
subContext,
8221+
inheritanceDecl->witnessTable,
8222+
irWitnessTable,
8223+
mapASTToIRWitnessTable);
8224+
}
8225+
irWitnessTable->moveToEnd();
81988226
}
8199-
irWitnessTable->moveToEnd();
82008227

82018228
return LoweredValInfo::simple(
82028229
finishOuterGenerics(subBuilder, irWitnessTable, outerGeneric));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//TEST:SIMPLE(filecheck=CHK):-stage compute -entry computeMain -target hlsl
2+
3+
//CHK: struct DiffPair_1
4+
//CHK-NOT: struct DiffPair_2
5+
6+
RWTexture2D<float> gOutputColor;
7+
8+
struct ShadingFrame : IDifferentiable
9+
{
10+
float3 T;
11+
}
12+
13+
[Differentiable]
14+
float computeRay()
15+
{
16+
float3 dir = 1.f;
17+
return dot(dir, dir);
18+
}
19+
20+
[Differentiable]
21+
float paramRay()
22+
{
23+
DifferentialPair<float> dpDir = fwd_diff(computeRay)();
24+
return dpDir.p;
25+
}
26+
27+
[Shader("compute")]
28+
[NumThreads(1, 1, 1)]
29+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
30+
{
31+
DifferentialPair<float> dpColor = fwd_diff(paramRay)();
32+
gOutputColor[0] = dpColor.p;
33+
}

0 commit comments

Comments
 (0)