Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make IRWitnessTable HOISTABLE #6417

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Clean up in autodiff
  • Loading branch information
jkwak-work committed Mar 20, 2025
commit 2b17660c0dc55469731192605ca5c03d660ddcde
226 changes: 127 additions & 99 deletions source/slang/slang-ir-autodiff.cpp
Original file line number Diff line number Diff line change
@@ -1862,17 +1862,20 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
sharedContext->differentiableInterfaceType,
(IRType*)pairType);

// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeStructKey,
diffDiffPairType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeWitnessStructKey,
table);
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
if (table->getFirstDecorationOrChild() == nullptr)
{
// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeStructKey,
diffDiffPairType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeWitnessStructKey,
table);
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
}

bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;

@@ -1944,15 +1947,18 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)pairType);

// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeStructKey,
diffDiffPairType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeWitnessStructKey,
table);
if (table->getFirstDecorationOrChild() == nullptr)
{
// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeStructKey,
diffDiffPairType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeWitnessStructKey,
table);
}
}

return table;
@@ -1987,17 +1993,20 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
sharedContext->differentiableInterfaceType,
(IRType*)arrayType);

// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeStructKey,
diffArrayType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeWitnessStructKey,
table);
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
if (table->getFirstDecorationOrChild() == nullptr)
{
// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeStructKey,
diffArrayType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeWitnessStructKey,
table);
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
}

auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();

@@ -2066,15 +2075,18 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)arrayType);

// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeStructKey,
diffArrayType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeWitnessStructKey,
table);
if (table->getFirstDecorationOrChild() == nullptr)
{
// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeStructKey,
diffArrayType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeWitnessStructKey,
table);
}
}
else
{
@@ -2105,17 +2117,20 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
sharedContext->differentiableInterfaceType,
(IRType*)inTupleType);

// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeStructKey,
diffTupleType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeWitnessStructKey,
table);
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
if (table->getFirstDecorationOrChild() == nullptr)
{
// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeStructKey,
diffTupleType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocTypeWitnessStructKey,
table);
builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
}

// Fill in differential method implementations.
{
@@ -2219,15 +2234,18 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
sharedContext->differentiablePtrInterfaceType,
(IRType*)inTupleType);

// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeStructKey,
diffTupleType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeWitnessStructKey,
table);
if (table->getFirstDecorationOrChild() == nullptr)
{
// And place it in the synthesized witness table.
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeStructKey,
diffTupleType);
builder->createWitnessTableEntry(
table,
sharedContext->differentialAssocRefTypeWitnessStructKey,
table);
}
}

return table;
@@ -3078,39 +3096,45 @@ struct AutoDiffPass : public InstPassBase
builder.createWitnessTable(autodiffContext->differentiableInterfaceType, originalType);
result.diffWitness = origTypeIsDiffWitness;

builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->differentialAssocTypeStructKey,
diffType);
builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->differentialAssocTypeWitnessStructKey,
diffTypeIsDiffWitness);
builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->zeroMethodStructKey,
zeroMethod);
builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->addMethodStructKey,
addMethod);

builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->differentialAssocTypeStructKey,
diffType);
builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->differentialAssocTypeWitnessStructKey,
diffTypeIsDiffWitness);
builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->zeroMethodStructKey,
zeroMethod);
builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->addMethodStructKey,
addMethod);
if (origTypeIsDiffWitness->getFirstDecorationOrChild() == nullptr)
{
builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->differentialAssocTypeStructKey,
diffType);
builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->differentialAssocTypeWitnessStructKey,
diffTypeIsDiffWitness);
builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->zeroMethodStructKey,
zeroMethod);
builder.createWitnessTableEntry(
origTypeIsDiffWitness,
autodiffContext->addMethodStructKey,
addMethod);
}

if (diffTypeIsDiffWitness->getFirstDecorationOrChild() == nullptr)
{
builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->differentialAssocTypeStructKey,
diffType);
builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->differentialAssocTypeWitnessStructKey,
diffTypeIsDiffWitness);
builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->zeroMethodStructKey,
zeroMethod);
builder.createWitnessTableEntry(
diffTypeIsDiffWitness,
autodiffContext->addMethodStructKey,
addMethod);
}
return result;
}

@@ -3178,6 +3202,7 @@ struct AutoDiffPass : public InstPassBase
for (auto param : genType->getParams())
args.add(param);

// Create a new WitnessTable with a different concreteType.
auto concreteType = as<IRType>(builder.emitSpecializeInst(
builder.getTypeKind(),
originalType,
@@ -3187,13 +3212,16 @@ struct AutoDiffPass : public InstPassBase
auto witnessTableType = innerResult.diffWitness->getFullType();
auto newWitnessTable = builder.createWitnessTable(witnessTableType, concreteType);

// Copy all entries from the old witness table to the new one
for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
if (newWitnessTable->getFirstDecorationOrChild() == nullptr)
{
builder.createWitnessTableEntry(
newWitnessTable,
entry->getRequirementKey(),
entry->getSatisfyingVal());
builder.setInsertInto(newWitnessTable);
for (auto entry : as<IRWitnessTable>(innerResult.diffWitness)->getEntries())
{
builder.createWitnessTableEntry(
newWitnessTable,
entry->getRequirementKey(),
entry->getSatisfyingVal());
}
}

result.diffWitness =
1 change: 1 addition & 0 deletions source/slang/slang-ir-link.cpp
Original file line number Diff line number Diff line change
@@ -732,6 +732,7 @@ IRWitnessTable* cloneWitnessTableImpl(
clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
SLANG_RELEASE_ASSERT(clonedTable->getFirstDecorationOrChild() == nullptr);
}
else
{
11 changes: 3 additions & 8 deletions source/slang/slang-ir-specialize.cpp
Original file line number Diff line number Diff line change
@@ -3123,18 +3123,13 @@ IRInst* specializeGenericImpl(
//
IRInstList<IRInst> ordinaryInsts = bb->getOrdinaryInsts();

// After IRWitnessTable became Hoistable, they are removed and inserted back by
// After IRWitnessTable became Hoistable, they are removed and insered back by
// `addHoistableInst()`. But when they are re-inserted, the order it appears in the block is
// changed. We need to change the order in a way that the dependancy is resolved.
//
// The dependency cannot be resolve in `addHoistableInst()`, because IRWitnessTable doesn't
// have IRWitnessTableEntry yet while in the function.
//
// When IRWitnessTable refers to IRSpecialize, as an example, IRSpecialize must be
// changed. When IRWitnessTable refers to IRSpecialize, as an example, IRSpecialize must be
// cloned before the cloning of IRWitnessTable. It is because the operands are assumed to be
// cloned before the cloning of IRInst.
//
// Similarly, there can be dependencies between an IRWitnessTable and another IRWitnessTable.
// We need to resolve the dependency problem by changing the order of them.
//
List<IRInst*> insts;
int instCount = 0;
30 changes: 15 additions & 15 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
@@ -8046,10 +8046,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irWitnessTableBaseType,
irWitnessTable->getConcreteType());

// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return a
// pre-existing one. We need to avoid adding the same decorations/children
// when the IRWitnessTable already has them.
//
// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return an
// IRWitnessTable that already has decorations/children. We need to avoid
// adding them more than once.
if (irSatisfyingWitnessTable->getFirstDecorationOrChild() == nullptr)
{
auto mangledName = getMangledNameForConformanceWitness(
@@ -8071,9 +8070,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
astReqWitnessTable,
irSatisfyingWitnessTable,
mapASTToIRWitnessTable);

irSatisfyingWitnessTable->moveToEnd();
}

irSatisfyingWitnessTable->moveToEnd();
}
irSatisfyingVal = irSatisfyingWitnessTable;
}
@@ -8194,17 +8193,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Create the IR-level witness table
auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, irSubType);

// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return a
// pre-existing one. We need to avoid adding the same decorations/children
// when the IRWitnessTable already has them.
// Override with the correct witness-table
context->setGlobalValue(
inheritanceDecl,
LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));

// Since IRWitnessTable is Hoistable, `createWitnessTable()` may return an
// IRWitnessTable that already has decorations/children. We need to avoid adding them
// more than once.
//
if (irWitnessTable->getFirstDecorationOrChild() == nullptr)
{
// Override with the correct witness-table
context->setGlobalValue(
inheritanceDecl,
LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable)));

// TODO(JS):
// Should the mangled name take part in obfuscation if enabled?

@@ -8246,9 +8245,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irWitnessTable,
mapASTToIRWitnessTable);
}
irWitnessTable->moveToEnd();
}

irWitnessTable->moveToEnd();

return LoweredValInfo::simple(
finishOuterGenerics(subBuilder, irWitnessTable, outerGeneric));
}