Skip to content

Commit bdb3ccf

Browse files
committed
Make IRWitnessTable HOISTABLE
This commit is to remove the duplications of IRWitnessTable by making it HOISTABLE.
1 parent 9580e31 commit bdb3ccf

7 files changed

+71
-16
lines changed

source/slang/slang-ir-autodiff.cpp

+20-7
Original file line numberDiff line numberDiff line change
@@ -3179,14 +3179,27 @@ struct AutoDiffPass : public InstPassBase
31793179
List<IRInst*> args;
31803180
for (auto param : genType->getParams())
31813181
args.add(param);
3182-
as<IRWitnessTable>(innerResult.diffWitness)
3183-
->setConcreteType((IRType*)builder.emitSpecializeInst(
3184-
builder.getTypeKind(),
3185-
originalType,
3186-
(UInt)args.getCount(),
3187-
args.getBuffer()));
3182+
3183+
auto concreteType = as<IRType>(builder.emitSpecializeInst(
3184+
builder.getTypeKind(),
3185+
originalType,
3186+
(UInt)args.getCount(),
3187+
args.getBuffer()));
3188+
3189+
auto witnessTableType = innerResult.diffWitness->getFullType();
3190+
auto newWitnessTable = builder.createWitnessTable(witnessTableType, concreteType);
3191+
3192+
// Copy all entries from the old witness table to the new one
3193+
for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
3194+
{
3195+
builder.createWitnessTableEntry(
3196+
newWitnessTable,
3197+
entry->getRequirementKey(),
3198+
entry->getSatisfyingVal());
3199+
}
3200+
31883201
result.diffWitness =
3189-
hoistValueFromGeneric(builder, innerResult.diffWitness, specInst, true);
3202+
hoistValueFromGeneric(builder, newWitnessTable, specInst, true);
31903203
}
31913204
return result;
31923205
}

source/slang/slang-ir-clone.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,12 @@ void cloneDecoration(
309309
IRInst* newParent,
310310
IRModule* module)
311311
{
312+
// If the parent is hoistable and is the same as the original decoration's parent,
313+
// skip cloning to avoid duplicating decorations
314+
if (getIROpInfo(newParent->getOp()).isHoistable() &&
315+
oldDecoration->getParent() == newParent)
316+
return;
317+
312318
IRBuilder builder(module);
313319

314320
if (auto first = newParent->getFirstDecorationOrChild())

source/slang/slang-ir-inst-defs.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ INST(GlobalConstant, globalConstant, 0, GLOBAL)
294294

295295
INST(StructKey, key, 0, GLOBAL)
296296
INST(GlobalGenericParam, global_generic_param, 0, GLOBAL)
297-
INST(WitnessTable, witness_table, 0, 0)
297+
INST(WitnessTable, witness_table, 0, HOISTABLE)
298298

299299
INST(IndexedFieldKey, indexedFieldKey, 2, HOISTABLE)
300300

source/slang/slang-ir-insts.h

-2
Original file line numberDiff line numberDiff line change
@@ -2934,8 +2934,6 @@ struct IRWitnessTable : IRInst
29342934

29352935
IRType* getConcreteType() { return (IRType*)getOperand(0); }
29362936

2937-
void setConcreteType(IRType* t) { return setOperand(0, t); }
2938-
29392937
IR_LEAF_ISA(WitnessTable)
29402938
};
29412939

source/slang/slang-ir-link.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
224224
case kIROp_StructKey:
225225
case kIROp_InterfaceRequirementEntry:
226226
case kIROp_GlobalGenericParam:
227-
case kIROp_WitnessTable:
228227
case kIROp_InterfaceType:
229228
return cloneGlobalValue(this, originalValue);
230229

source/slang/slang-lower-to-ir.cpp

+11-5
Original file line numberDiff line numberDiff line change
@@ -8138,17 +8138,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
81388138
//
81398139
auto irWitnessTableBaseType = lowerType(subContext, superType);
81408140

8141+
// Register a dummy value to avoid infinite recursions.
8142+
// Without this, the call to lowerType() can get into an infinite recursion.
8143+
//
8144+
context->setGlobalValue(
8145+
inheritanceDecl,
8146+
LoweredValInfo::simple(findOuterMostGeneric(subBuilder->getInsertLoc().getParent())));
8147+
8148+
auto irSubType = lowerType(subContext, subType);
8149+
81418150
// Create the IR-level witness table
8142-
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr);
8151+
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, irSubType);
81438152

8144-
// Register the value now, rather than later, to avoid any possible infinite recursion.
8153+
// Override with the correct witness-table
81458154
context->setGlobalValue(
81468155
inheritanceDecl,
81478156
LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
81488157

8149-
auto irSubType = lowerType(subContext, subType);
8150-
irWitnessTable->setConcreteType(irSubType);
8151-
81528158
// TODO(JS):
81538159
// Should the mangled name take part in obfuscation if enabled?
81548160

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//TEST:SIMPLE(filecheck=CHK):-stage compute -entry computeMain -target hlsl
2+
3+
//CHK: struct DiffPair_1
4+
//CHK-NOT: struct DiffPair_2
5+
6+
RWTexture2D<float> gOutputColor;
7+
8+
struct ShadingFrame : IDifferentiable
9+
{
10+
float3 T;
11+
}
12+
13+
[Differentiable]
14+
float computeRay()
15+
{
16+
float3 dir = 1.f;
17+
return dot(dir, dir);
18+
}
19+
20+
[Differentiable]
21+
float paramRay()
22+
{
23+
DifferentialPair<float> dpDir = fwd_diff(computeRay)();
24+
return dpDir.p;
25+
}
26+
27+
[Shader("compute")]
28+
[NumThreads(1, 1, 1)]
29+
void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
30+
{
31+
DifferentialPair<float> dpColor = fwd_diff(paramRay)();
32+
gOutputColor[0] = dpColor.p;
33+
}

0 commit comments

Comments
 (0)