Skip to content

Commit e27ae04

Browse files
committed
Clean up in autodiff
1 parent 73304dc commit e27ae04

File tree

4 files changed

+146
-122
lines changed

4 files changed

+146
-122
lines changed

source/slang/slang-ir-autodiff.cpp

+127-99
Original file line numberDiff line numberDiff line change
@@ -1861,17 +1861,20 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
18611861
sharedContext->differentiableInterfaceType,
18621862
(IRType*)pairType);
18631863

1864-
// And place it in the synthesized witness table.
1865-
builder->createWitnessTableEntry(
1866-
table,
1867-
sharedContext->differentialAssocTypeStructKey,
1868-
diffDiffPairType);
1869-
builder->createWitnessTableEntry(
1870-
table,
1871-
sharedContext->differentialAssocTypeWitnessStructKey,
1872-
table);
1873-
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
1874-
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
1864+
if (table->getFirstDecorationOrChild() == nullptr)
1865+
{
1866+
// And place it in the synthesized witness table.
1867+
builder->createWitnessTableEntry(
1868+
table,
1869+
sharedContext->differentialAssocTypeStructKey,
1870+
diffDiffPairType);
1871+
builder->createWitnessTableEntry(
1872+
table,
1873+
sharedContext->differentialAssocTypeWitnessStructKey,
1874+
table);
1875+
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
1876+
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
1877+
}
18751878

18761879
bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
18771880

@@ -1943,15 +1946,18 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
19431946
sharedContext->differentiablePtrInterfaceType,
19441947
(IRType*)pairType);
19451948

1946-
// And place it in the synthesized witness table.
1947-
builder->createWitnessTableEntry(
1948-
table,
1949-
sharedContext->differentialAssocRefTypeStructKey,
1950-
diffDiffPairType);
1951-
builder->createWitnessTableEntry(
1952-
table,
1953-
sharedContext->differentialAssocRefTypeWitnessStructKey,
1954-
table);
1949+
if (table->getFirstDecorationOrChild() == nullptr)
1950+
{
1951+
// And place it in the synthesized witness table.
1952+
builder->createWitnessTableEntry(
1953+
table,
1954+
sharedContext->differentialAssocRefTypeStructKey,
1955+
diffDiffPairType);
1956+
builder->createWitnessTableEntry(
1957+
table,
1958+
sharedContext->differentialAssocRefTypeWitnessStructKey,
1959+
table);
1960+
}
19551961
}
19561962

19571963
return table;
@@ -1986,17 +1992,20 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
19861992
sharedContext->differentiableInterfaceType,
19871993
(IRType*)arrayType);
19881994

1989-
// And place it in the synthesized witness table.
1990-
builder->createWitnessTableEntry(
1991-
table,
1992-
sharedContext->differentialAssocTypeStructKey,
1993-
diffArrayType);
1994-
builder->createWitnessTableEntry(
1995-
table,
1996-
sharedContext->differentialAssocTypeWitnessStructKey,
1997-
table);
1998-
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
1999-
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
1995+
if (table->getFirstDecorationOrChild() == nullptr)
1996+
{
1997+
// And place it in the synthesized witness table.
1998+
builder->createWitnessTableEntry(
1999+
table,
2000+
sharedContext->differentialAssocTypeStructKey,
2001+
diffArrayType);
2002+
builder->createWitnessTableEntry(
2003+
table,
2004+
sharedContext->differentialAssocTypeWitnessStructKey,
2005+
table);
2006+
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
2007+
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
2008+
}
20002009

20012010
auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
20022011

@@ -2065,15 +2074,18 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
20652074
sharedContext->differentiablePtrInterfaceType,
20662075
(IRType*)arrayType);
20672076

2068-
// And place it in the synthesized witness table.
2069-
builder->createWitnessTableEntry(
2070-
table,
2071-
sharedContext->differentialAssocRefTypeStructKey,
2072-
diffArrayType);
2073-
builder->createWitnessTableEntry(
2074-
table,
2075-
sharedContext->differentialAssocRefTypeWitnessStructKey,
2076-
table);
2077+
if (table->getFirstDecorationOrChild() == nullptr)
2078+
{
2079+
// And place it in the synthesized witness table.
2080+
builder->createWitnessTableEntry(
2081+
table,
2082+
sharedContext->differentialAssocRefTypeStructKey,
2083+
diffArrayType);
2084+
builder->createWitnessTableEntry(
2085+
table,
2086+
sharedContext->differentialAssocRefTypeWitnessStructKey,
2087+
table);
2088+
}
20772089
}
20782090
else
20792091
{
@@ -2106,17 +2118,20 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
21062118
sharedContext->differentiableInterfaceType,
21072119
(IRType*)inTupleType);
21082120

2109-
// And place it in the synthesized witness table.
2110-
builder->createWitnessTableEntry(
2111-
table,
2112-
sharedContext->differentialAssocTypeStructKey,
2113-
diffTupleType);
2114-
builder->createWitnessTableEntry(
2115-
table,
2116-
sharedContext->differentialAssocTypeWitnessStructKey,
2117-
table);
2118-
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
2119-
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
2121+
if (table->getFirstDecorationOrChild() == nullptr)
2122+
{
2123+
// And place it in the synthesized witness table.
2124+
builder->createWitnessTableEntry(
2125+
table,
2126+
sharedContext->differentialAssocTypeStructKey,
2127+
diffTupleType);
2128+
builder->createWitnessTableEntry(
2129+
table,
2130+
sharedContext->differentialAssocTypeWitnessStructKey,
2131+
table);
2132+
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
2133+
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
2134+
}
21202135

21212136
// Fill in differential method implementations.
21222137
{
@@ -2219,15 +2234,18 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
22192234
sharedContext->differentiablePtrInterfaceType,
22202235
(IRType*)inTupleType);
22212236

2222-
// And place it in the synthesized witness table.
2223-
builder->createWitnessTableEntry(
2224-
table,
2225-
sharedContext->differentialAssocRefTypeStructKey,
2226-
diffTupleType);
2227-
builder->createWitnessTableEntry(
2228-
table,
2229-
sharedContext->differentialAssocRefTypeWitnessStructKey,
2230-
table);
2237+
if (table->getFirstDecorationOrChild() == nullptr)
2238+
{
2239+
// And place it in the synthesized witness table.
2240+
builder->createWitnessTableEntry(
2241+
table,
2242+
sharedContext->differentialAssocRefTypeStructKey,
2243+
diffTupleType);
2244+
builder->createWitnessTableEntry(
2245+
table,
2246+
sharedContext->differentialAssocRefTypeWitnessStructKey,
2247+
table);
2248+
}
22312249
}
22322250

22332251
return table;
@@ -3078,39 +3096,45 @@ struct AutoDiffPass : public InstPassBase
30783096
builder.createWitnessTable(autodiffContext->differentiableInterfaceType, originalType);
30793097
result.diffWitness = origTypeIsDiffWitness;
30803098

3081-
builder.createWitnessTableEntry(
3082-
origTypeIsDiffWitness,
3083-
autodiffContext->differentialAssocTypeStructKey,
3084-
diffType);
3085-
builder.createWitnessTableEntry(
3086-
origTypeIsDiffWitness,
3087-
autodiffContext->differentialAssocTypeWitnessStructKey,
3088-
diffTypeIsDiffWitness);
3089-
builder.createWitnessTableEntry(
3090-
origTypeIsDiffWitness,
3091-
autodiffContext->zeroMethodStructKey,
3092-
zeroMethod);
3093-
builder.createWitnessTableEntry(
3094-
origTypeIsDiffWitness,
3095-
autodiffContext->addMethodStructKey,
3096-
addMethod);
3097-
3098-
builder.createWitnessTableEntry(
3099-
diffTypeIsDiffWitness,
3100-
autodiffContext->differentialAssocTypeStructKey,
3101-
diffType);
3102-
builder.createWitnessTableEntry(
3103-
diffTypeIsDiffWitness,
3104-
autodiffContext->differentialAssocTypeWitnessStructKey,
3105-
diffTypeIsDiffWitness);
3106-
builder.createWitnessTableEntry(
3107-
diffTypeIsDiffWitness,
3108-
autodiffContext->zeroMethodStructKey,
3109-
zeroMethod);
3110-
builder.createWitnessTableEntry(
3111-
diffTypeIsDiffWitness,
3112-
autodiffContext->addMethodStructKey,
3113-
addMethod);
3099+
if (origTypeIsDiffWitness->getFirstDecorationOrChild() == nullptr)
3100+
{
3101+
builder.createWitnessTableEntry(
3102+
origTypeIsDiffWitness,
3103+
autodiffContext->differentialAssocTypeStructKey,
3104+
diffType);
3105+
builder.createWitnessTableEntry(
3106+
origTypeIsDiffWitness,
3107+
autodiffContext->differentialAssocTypeWitnessStructKey,
3108+
diffTypeIsDiffWitness);
3109+
builder.createWitnessTableEntry(
3110+
origTypeIsDiffWitness,
3111+
autodiffContext->zeroMethodStructKey,
3112+
zeroMethod);
3113+
builder.createWitnessTableEntry(
3114+
origTypeIsDiffWitness,
3115+
autodiffContext->addMethodStructKey,
3116+
addMethod);
3117+
}
3118+
3119+
if (diffTypeIsDiffWitness->getFirstDecorationOrChild() == nullptr)
3120+
{
3121+
builder.createWitnessTableEntry(
3122+
diffTypeIsDiffWitness,
3123+
autodiffContext->differentialAssocTypeStructKey,
3124+
diffType);
3125+
builder.createWitnessTableEntry(
3126+
diffTypeIsDiffWitness,
3127+
autodiffContext->differentialAssocTypeWitnessStructKey,
3128+
diffTypeIsDiffWitness);
3129+
builder.createWitnessTableEntry(
3130+
diffTypeIsDiffWitness,
3131+
autodiffContext->zeroMethodStructKey,
3132+
zeroMethod);
3133+
builder.createWitnessTableEntry(
3134+
diffTypeIsDiffWitness,
3135+
autodiffContext->addMethodStructKey,
3136+
addMethod);
3137+
}
31143138
return result;
31153139
}
31163140

@@ -3178,6 +3202,7 @@ struct AutoDiffPass : public InstPassBase
31783202
for (auto param : genType->getParams())
31793203
args.add(param);
31803204

3205+
// Create a new WitnessTable with a different concreteType.
31813206
auto concreteType = as<IRType>(builder.emitSpecializeInst(
31823207
builder.getTypeKind(),
31833208
originalType,
@@ -3187,13 +3212,16 @@ struct AutoDiffPass : public InstPassBase
31873212
auto witnessTableType = innerResult.diffWitness->getFullType();
31883213
auto newWitnessTable = builder.createWitnessTable(witnessTableType, concreteType);
31893214

3190-
// Copy all entries from the old witness table to the new one
3191-
for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
3215+
if (newWitnessTable->getFirstDecorationOrChild() == nullptr)
31923216
{
3193-
builder.createWitnessTableEntry(
3194-
newWitnessTable,
3195-
entry->getRequirementKey(),
3196-
entry->getSatisfyingVal());
3217+
builder.setInsertInto(newWitnessTable);
3218+
for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
3219+
{
3220+
builder.createWitnessTableEntry(
3221+
newWitnessTable,
3222+
entry->getRequirementKey(),
3223+
entry->getSatisfyingVal());
3224+
}
31973225
}
31983226

31993227
result.diffWitness =

source/slang/slang-ir-link.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ IRWitnessTable* cloneWitnessTableImpl(
732732
clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
733733
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
734734
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
735+
SLANG_RELEASE_ASSERT(clonedTable->getFirstDecorationOrChild() == nullptr);
735736
}
736737
else
737738
{

source/slang/slang-ir-specialize.cpp

+3-8
Original file line numberDiff line numberDiff line change
@@ -3123,18 +3123,13 @@ IRInst* specializeGenericImpl(
31233123
//
31243124
IRInstList<IRInst> ordinaryInsts = bb->getOrdinaryInsts();
31253125

3126-
// After IRWitnessTable became Hoistable, they are removed and inserted back by
3126+
// After IRWitnessTable became Hoistable, they are removed and insered back by
31273127
// `addHoistableInst()`. But when they are re-inserted, the order it appears in the block is
3128-
// changed. We need to change the order in a way that the dependancy is resolved.
3129-
//
3130-
// The dependency cannot be resolve in `addHoistableInst()`, because IRWitnessTable doesn't
3131-
// have IRWitnessTableEntry yet while in the function.
3132-
//
3133-
// When IRWitnessTable refers to IRSpecialize, as an example, IRSpecialize must be
3128+
// changed. When IRWitnessTable refers to IRSpecialize, as an example, IRSpecialize must be
31343129
// cloned before the cloning of IRWitnessTable. It is because the operands are assumed to be
31353130
// cloned before the cloning of IRInst.
31363131
//
3137-
// Similarly, there can be dependencies between an IRWitnessTable and another IRWitnessTable.
3132+
// We need to resolve the dependency problem by changing the order of them.
31383133
//
31393134
List<IRInst*> insts;
31403135
int instCount = 0;

source/slang/slang-lower-to-ir.cpp

+15-15
Original file line numberDiff line numberDiff line change
@@ -8016,10 +8016,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
80168016
irWitnessTableBaseType,
80178017
irWitnessTable->getConcreteType());
80188018

8019-
// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return a
8020-
// pre-existing one. We need to avoid adding the same decorations/children
8021-
// when the IRWitnessTable already has them.
8022-
//
8019+
// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return an
8020+
// IRWitnessTable that already has decorations/children. We need to avoid
8021+
// adding them more than once.
80238022
if (irSatisfyingWitnessTable->getFirstDecorationOrChild() == nullptr)
80248023
{
80258024
auto mangledName = getMangledNameForConformanceWitness(
@@ -8041,9 +8040,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
80418040
astReqWitnessTable,
80428041
irSatisfyingWitnessTable,
80438042
mapASTToIRWitnessTable);
8044-
8045-
irSatisfyingWitnessTable->moveToEnd();
80468043
}
8044+
8045+
irSatisfyingWitnessTable->moveToEnd();
80478046
}
80488047
irSatisfyingVal = irSatisfyingWitnessTable;
80498048
}
@@ -8164,17 +8163,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
81648163
// Create the IR-level witness table
81658164
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, irSubType);
81668165

8167-
// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return a
8168-
// pre-existing one. We need to avoid adding the same decorations/children
8169-
// when the IRWitnessTable already has them.
8166+
// Override with the correct witness-table
8167+
context->setGlobalValue(
8168+
inheritanceDecl,
8169+
LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
8170+
8171+
// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return an
8172+
// IRWitnessTable that already has decorations/children. We need to avoid adding them
8173+
// more than once.
81708174
//
81718175
if (irWitnessTable->getFirstDecorationOrChild() == nullptr)
81728176
{
8173-
// Override with the correct witness-table
8174-
context->setGlobalValue(
8175-
inheritanceDecl,
8176-
LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));
8177-
81788177
// TODO(JS):
81798178
// Should the mangled name take part in obfuscation if enabled?
81808179

@@ -8215,9 +8214,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
82158214
irWitnessTable,
82168215
mapASTToIRWitnessTable);
82178216
}
8218-
irWitnessTable->moveToEnd();
82198217
}
82208218

8219+
irWitnessTable->moveToEnd();
8220+
82218221
return LoweredValInfo::simple(
82228222
finishOuterGenerics(subBuilder, irWitnessTable, outerGeneric));
82238223
}

0 commit comments

Comments
 (0)