Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix UseGraph::replace #6395

Merged
merged 11 commits into from
Feb 25, 2025
4 changes: 2 additions & 2 deletions source/slang/slang-ir-autodiff-fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1777,7 +1777,7 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)

for (auto param : params)
{
auto ptrType = as<IRPtrTypeBase>(param->getDataType());
auto ptrType = asRelevantPtrType(param->getDataType());
auto tempVar = builder.emitVar(ptrType->getValueType());
param->replaceUsesWith(tempVar);
mapParamToTempVar[param] = tempVar;
Expand Down Expand Up @@ -2245,7 +2245,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(
builder->emitDifferentialPairGetPrimal(diffPairParam),
builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
}
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
else if (auto pairPtrType = asRelevantPtrType(diffPairType))
{
auto ptrInnerPairType = as<IRDifferentialPairTypeBase>(pairPtrType->getValueType());
// Make a local copy of the parameter for primal and diff parts.
Expand Down
88 changes: 38 additions & 50 deletions source/slang/slang-ir-autodiff-primal-hoist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,7 @@ IRVar* emitIndexedLocalVar(
SourceLoc location)
{
// Cannot store pointers. Case should have been handled by now.
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
SLANG_RELEASE_ASSERT(!asRelevantPtrType(baseType));

// Cannot store types. Case should have been handled by now.
SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));
Expand Down Expand Up @@ -1326,7 +1326,11 @@ static int getInstRegionNestLevel(

struct UseChain
{
// The chain of uses from the base use to the relevant use.
// However, this is stored in reverse order (so that the last use is the 'base use')
//
List<IRUse*> chain;

static List<UseChain> from(
IRUse* baseUse,
Func<bool, IRUse*> isRelevantUse,
Expand Down Expand Up @@ -1366,41 +1370,20 @@ struct UseChain
return result;
}

void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst)
// This function only replaces the inner links, not the base use.
void replaceInnerLinks(IROutOfOrderCloneContext* ctx, IRBuilder* builder)
{
SLANG_ASSERT(chain.getCount() > 0);

// Simple case: if there is only one use, then we can just replace it.
if (chain.getCount() == 1)
{
builder->replaceOperand(chain.getLast(), inst);
chain.clear();
return;
}

// Pop the last use, which is the base use that needs to be replaced.
auto baseUse = chain.getLast();
chain.removeLast();
const UIndex count = chain.getCount();

// Ensure that replacement inst is set as mapping for the baseUse.
ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst;

IRBuilder chainBuilder(builder->getModule());
setInsertAfterOrdinaryInst(&chainBuilder, inst);

chain.reverse();
chain.removeLast();

// Clone the rest of the chain.
for (auto& use : chain)
// Process the chain in reverse order (excluding the first and last elements).
// That is, iterate from count - 2 down to 1 (inclusive).
for (int i = ((int)count) - 2; i >= 1; i--)
{
ctx->cloneInstOutOfOrder(&chainBuilder, use->get());
IRUse* use = chain[i];
ctx->cloneInstOutOfOrder(builder, use->get());
}

// We won't actually replace the final use, because if there are multiple chains
// it can cause problems. The parent UseGraph will handle that.

chain.clear();
}

IRInst* getUser() const
Expand All @@ -1417,6 +1400,14 @@ struct UseGraph
//
OrderedDictionary<IRUse*, List<UseChain>> chainSets;

// Create a UseGraph from a base inst.
//
// `isRelevantUse` is a predicate that determines if a use is relevant. Traversal will stop at
// this use, and all chains to this use will be grouped together.
//
// `passthroughInst` is a predicate that determines if an inst should be looked through
// for uses.
//
static UseGraph from(
IRInst* baseInst,
Func<bool, IRUse*> isRelevantUse,
Expand Down Expand Up @@ -1445,36 +1436,33 @@ struct UseGraph
return result;
}

void replace(IRBuilder* builder, IRUse* use, IRInst* inst)
void replace(IRBuilder* builder, IRUse* relevantUse, IRInst* inst)
{
// Since we may have common nodes, we will use an out-of-order cloning context
// that can retroactively correct the uses as needed.
//
IROutOfOrderCloneContext ctx;
List<UseChain> chains = chainSets[use];
for (auto chain : chains)
{
chain.replace(&ctx, builder, inst);
}
List<UseChain> chains = chainSets[relevantUse];

if (!isTrivial())
// Link the first use of each chain to inst.
for (auto& chain : chains)
ctx.cloneEnv.mapOldValToNew[chain.chain.getLast()->get()] = inst;

// Process the inner links of each chain using the replacement.
for (auto& chain : chains)
{
builder->setInsertBefore(use->getUser());
auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get());
IRBuilder chainBuilder(builder->getModule());
setInsertAfterOrdinaryInst(&chainBuilder, inst);

// Replace the base use.
builder->replaceOperand(use, lastInstInChain);
chain.replaceInnerLinks(&ctx, builder);
}
}

bool isTrivial()
{
// We're trivial if there's only one chain, and it has only one use.
if (chainSets.getCount() != 1)
return false;
// Finally, replace the relevant use (i.e, "final use") with the new replacement inst.
builder->setInsertBefore(relevantUse->getUser());
auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, relevantUse->get());

auto& chain = chainSets.getFirst().value;
return chain.getCount() == 1;
// Replace the base use.
builder->replaceOperand(relevantUse, lastInstInChain);
}

List<IRUse*> getUniqueUses() const
Expand Down Expand Up @@ -1668,7 +1656,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
return true;
}
else if (
as<IRPtrTypeBase>(instToStore->getDataType()) &&
asRelevantPtrType(instToStore->getDataType()) &&
!isDifferentialOrRecomputeBlock(defBlock))
{
return true;
Expand Down
6 changes: 3 additions & 3 deletions source/slang/slang-ir-autodiff-rev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(
auto diffPairType = tryGetDiffPairType(builder, paramType);
if (diffPairType)
{
if (!as<IRPtrTypeBase>(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
if (!asRelevantPtrType(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
return builder->getInOutType(diffPairType);
return diffPairType;
}
Expand Down Expand Up @@ -514,7 +514,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
{
// As long as the primal parameter is not an out or constref type,
// we need to fetch the primal value from the parameter.
if (as<IRPtrTypeBase>(propagateParamType))
if (asRelevantPtrType(propagateParamType))
{
primalArg = builder.emitLoad(param);
}
Expand Down Expand Up @@ -544,7 +544,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
}
else
{
auto primalPtrType = as<IRPtrTypeBase>(primalParamType);
auto primalPtrType = asRelevantPtrType(primalParamType);
SLANG_RELEASE_ASSERT(primalPtrType);
auto primalValueType = primalPtrType->getValueType();
auto var = builder.emitVar(primalValueType);
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-ir-autodiff-transcriber-base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
if (isNoDiffType(origType))
return nullptr;

if (auto ptrType = as<IRPtrTypeBase>(origType))
if (auto ptrType = asRelevantPtrType(origType))
return builder->getPtrType(
origType->getOp(),
differentiateType(builder, ptrType->getValueType()));
Expand Down Expand Up @@ -556,7 +556,7 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
if (isNoDiffType(originalType))
return nullptr;

if (auto origPtrType = as<IRPtrTypeBase>(originalType))
if (auto origPtrType = asRelevantPtrType(originalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
return builder->getPtrType(originalType->getOp(), diffPairValueType);
Expand Down
7 changes: 3 additions & 4 deletions source/slang/slang-ir-autodiff-transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ struct DiffTransposePass
if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst))
{
if (auto ptrPrimalType =
as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst)))
asRelevantPtrType(tryGetPrimalTypeFromDiffInst(varInst)))
{
varInst->insertAtEnd(firstRevDiffBlock);

Expand Down Expand Up @@ -1119,7 +1119,7 @@ struct DiffTransposePass

auto getDiffPairType = [](IRType* type)
{
if (auto ptrType = as<IRPtrTypeBase>(type))
if (auto ptrType = asRelevantPtrType(type))
type = ptrType->getValueType();
return as<IRDifferentialPairType>(type);
};
Expand Down Expand Up @@ -1168,7 +1168,7 @@ struct DiffTransposePass
argRequiresLoad.add(false);
writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
}
else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
else if (!asRelevantPtrType(arg->getDataType()) && getDiffPairType(arg->getDataType()))
{
// Normal differentiable input parameter will become an inout DiffPair parameter
// in the propagate func. The split logic has already prepared the initial value
Expand Down Expand Up @@ -1241,7 +1241,6 @@ struct DiffTransposePass
argRequiresLoad.add(false);
}


auto revFnType =
this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType(
builder,
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-ir-autodiff-unzip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ bool isIntermediateContextType(IRInst* type)
case kIROp_Specialize:
return isIntermediateContextType(as<IRSpecialize>(type)->getBase());
default:
if (as<IRPtrTypeBase>(type))
return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
if (auto ptrType = asRelevantPtrType(type))
return isIntermediateContextType(ptrType->getValueType());
return false;
}
}
Expand Down
7 changes: 3 additions & 4 deletions source/slang/slang-ir-autodiff-unzip.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,15 @@ struct DiffUnzipPass
primalParam = primalParam->getNextParam())
{
auto type = primalParam->getFullType();
if (auto ptrType = as<IRPtrTypeBase>(type))
if (auto ptrType = asRelevantPtrType(type))
{
type = ptrType->getValueType();
}
if (auto pairType = as<IRDifferentialPairType>(type))
{
IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType);
if (as<IRPtrTypeBase>(primalParam->getFullType()))
diffType =
builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType);
if (auto ptrType = asRelevantPtrType(primalParam->getFullType()))
diffType = builder->getPtrType(ptrType->getOp(), (IRType*)diffType);
auto primalRef = builder->emitPrimalParamRef(primalParam);
auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam);
builder->markInstAsDifferential(diffRef, pairType->getValueType());
Expand Down
10 changes: 5 additions & 5 deletions source/slang/slang-ir-autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ bool isNoDiffType(IRType* paramType)

paramType = attrType->getBaseType();
}
else if (auto ptrType = as<IRPtrTypeBase>(paramType))
else if (auto ptrType = asRelevantPtrType(paramType))
{
paramType = ptrType->getValueType();
}
Expand Down Expand Up @@ -184,7 +184,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
IRStructKey* key)
{
IRInst* pairType = nullptr;
if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
if (auto basePtrType = asRelevantPtrType(baseInst->getDataType()))
{
auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType());

Expand All @@ -203,7 +203,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
baseInst,
key));
}
else if (auto ptrType = as<IRPtrTypeBase>(pairType))
else if (auto ptrType = asRelevantPtrType(pairType))
{
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
{
Expand Down Expand Up @@ -240,7 +240,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
baseInst,
key));
}
else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
else if (auto genericPtrType = asRelevantPtrType(genericType))
{
if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
{
Expand Down Expand Up @@ -1646,7 +1646,7 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
IRBuilder* builder,
IRInst* primalType)
{
if (auto ptrType = as<IRPtrTypeBase>(primalType))
if (auto ptrType = asRelevantPtrType(primalType))
return builder->getPtrType(
primalType->getOp(),
differentiateType(builder, ptrType->getValueType()));
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-autodiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ inline bool isRelevantDifferentialPair(IRType* type)
{
return true;
}
else if (auto argPtrType = as<IRPtrTypeBase>(type))
else if (auto argPtrType = asRelevantPtrType(type))
{
if (as<IRDifferentialPairType>(argPtrType->getValueType()))
{
Expand Down
12 changes: 11 additions & 1 deletion source/slang/slang-ir-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1528,14 +1528,24 @@ bool isOne(IRInst* inst)
}
}

IRPtrTypeBase* asRelevantPtrType(IRInst* inst)
{
if (auto ptrType = as<IRPtrTypeBase>(inst))
{
if (ptrType->getAddressSpace() != AddressSpace::UserPointer)
return ptrType;
}
return nullptr;
}

IRPtrTypeBase* isMutablePointerType(IRInst* inst)
{
switch (inst->getOp())
{
case kIROp_ConstRefType:
return nullptr;
default:
return as<IRPtrTypeBase>(inst);
return asRelevantPtrType(inst);
}
}

Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-ir-util.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ bool isZero(IRInst* inst);

bool isOne(IRInst* inst);

// Casts inst to IRPtrTypeBase, excluding UserPointer address space.
IRPtrTypeBase* asRelevantPtrType(IRInst* inst);

// Returns the pointer type if it is pointer type that is not a const ref or a user pointer.
IRPtrTypeBase* isMutablePointerType(IRInst* inst);

void initializeScratchData(IRInst* inst);
Expand Down
Loading
Loading