Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug with hoisting 'IRVar' insts that are used outside the loop #6446

Merged
211 changes: 181 additions & 30 deletions source/slang/slang-ir-autodiff-primal-hoist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,142 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
return recomputeBlockMap;
}

// Checks if list A is a subset of list B by comparing their primal count parameters.
//
// Parameters:
// indicesA - First list of IndexTrackingInfo to compare
// indicesB - Second list of IndexTrackingInfo to compare
//
// Returns:
// true if all indices in indicesA are present in indicesB, false otherwise
//
bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
{
if (indicesA.getCount() > indicesB.getCount())
return false;

for (Index ii = 0; ii < indicesA.getCount(); ii++)
{
if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
return false;
}

return true;
}

bool canInstBeStored(IRInst* inst)
{
// Cannot store insts whose value is a type or a witness table, or a function.
// These insts get lowered to target-specific logic, and cannot be
// stored into variables or context structs as normal values.
//
if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
!inst->getDataType())
return false;

return true;
}

// This is a helper that converts insts in a loop condition block into two if necessary,
// then replaces all uses 'outside' the loop region with the new insts. This is because
// insts in loop condition blocks can be used in two distinct regions (the loop body, and
// after the loop).
//
// We'll use CheckpointObject for the splitting, which is allowed on any value-typed inst.
//
void splitLoopConditionBlockInsts(
IRGlobalValueWithCode* func,
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
{
// RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);

// Collect primal loop condition blocks, and map differential blocks to their primal blocks.
List<IRBlock*> loopConditionBlocks;
Dictionary<IRBlock*, IRBlock*> diffBlockMap;
for (auto block : func->getBlocks())
{
if (auto loop = as<IRLoop>(block->getTerminator()))
{
auto loopConditionBlock = getLoopConditionBlock(loop);
if (isDifferentialBlock(loopConditionBlock))
{
auto diffDecor = loopConditionBlock->findDecoration<IRDifferentialInstDecoration>();
diffBlockMap[cast<IRBlock>(diffDecor->getPrimalInst())] = loopConditionBlock;
}
else
loopConditionBlocks.add(loopConditionBlock);
}
}

// For each loop condition block, split the insts that are used in both the loop body and
// after the loop.
// Use the dominator tree to find uses of insts outside the loop body
//
// Essentially we want to split the uses dominated by the true block and the false block of the
// condition.
//
IRBuilder builder(func->getModule());


List<IRUse*> loopUses;
List<IRUse*> afterLoopUses;

for (auto condBlock : loopConditionBlocks)
{
// For each inst in the primal condition block, check if it has uses inside the loop body
// as well as outside of it. (Use the indexedBlockInfo to perform the teets)
//
for (auto inst = condBlock->getFirstInst(); inst; inst = inst->getNextInst())
{
// Skip terminators and insts that can't be stored
if (as<IRTerminatorInst>(inst) || !canInstBeStored(inst))
continue;
// Shouldn't see any vars.
SLANG_ASSERT(!as<IRVar>(inst));

// Get the indices for the condition block
auto& condBlockIndices = indexedBlockInfo[condBlock];

loopUses.clear();
afterLoopUses.clear();

// Check all uses of this inst
for (auto use = inst->firstUse; use; use = use->nextUse)
{
auto userBlock = getBlock(use->getUser());
auto& userBlockIndices = indexedBlockInfo[userBlock];

// If all of the condBlock's indices are a subset of the userBlock's indices,
// then the userBlock is inside the loop.
//
bool isInLoop = areIndicesSubsetOf(condBlockIndices, userBlockIndices);

if (isInLoop)
loopUses.add(use);
else
afterLoopUses.add(use);
}

// If inst has uses both inside and after the loop, create a copy for after-loop uses
if (loopUses.getCount() > 0 && afterLoopUses.getCount() > 0)
{
setInsertAfterOrdinaryInst(&builder, inst);
auto copy = builder.emitCheckpointObject(inst);

// Copy source location so that checkpoint reporting is accurate
copy->sourceLoc = inst->sourceLoc;

// Replace after-loop uses with the copy
for (auto use : afterLoopUses)
{
builder.replaceOperand(use, copy);
}
}
}
}
}

RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
IRGlobalValueWithCode* func,
Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
Expand Down Expand Up @@ -1297,20 +1433,6 @@ bool areIndicesEqual(
return true;
}

bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
{
if (indicesA.getCount() > indicesB.getCount())
return false;

for (Index ii = 0; ii < indicesA.getCount(); ii++)
{
if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
return false;
}

return true;
}

static int getInstRegionNestLevel(
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
IRBlock* defBlock,
Expand Down Expand Up @@ -1510,21 +1632,6 @@ static List<IndexTrackingInfo> maybeTrimIndices(
return result;
}

bool canInstBeStored(IRInst* inst)
{
// Cannot store insts whose value is a type or a witness table, or a function.
// These insts get lowered to target-specific logic, and cannot be
// stored into variables or context structs as normal values.
//
if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
!inst->getDataType())
return false;

return true;
}


/// Legalizes all accesses to primal insts from recompute and diff blocks.
///
RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
Expand Down Expand Up @@ -2104,6 +2211,39 @@ void buildIndexedBlocks(
}
}

// This function simply turns all CheckpointObject insts into a 'no-op'.
// i.e. simply replaces all uses of CheckpointObject with the original value.
//
// This operation is 'correct' because if CheckpointObject's operand is visible
// in a block, then it is visible in all dominated blocks.
//
void lowerCheckpointObjectInsts(IRGlobalValueWithCode* func)
{
// For each block in the function
for (auto block : func->getBlocks())
{
// For each instruction in the block
for (auto inst = block->getFirstInst(); inst;)
{
// Get next inst before potentially removing current one
auto nextInst = inst->getNextInst();

// Check if this is a CheckpointObject instruction
if (auto copyInst = as<IRCheckpointObject>(inst))
{
// Replace all uses of the copy with the original value
auto originalVal = copyInst->getVal();
copyInst->replaceUsesWith(originalVal);

// Remove the now unused copy instruction
inst->removeAndDeallocate();
}

inst = nextInst;
}
}
}

// For each primal inst that is used in reverse blocks, decide if we should recompute or store
// its value, then make them accessible in reverse blocks based the decision.
//
Expand All @@ -2117,6 +2257,9 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo;
buildIndexedBlocks(indexedBlockInfo, func);

// Split loop condition insts into two if necessary.
splitLoopConditionBlockInsts(func, indexedBlockInfo);

// Create recompute blocks for each region following the same control flow structure
// as in primal code.
//
Expand All @@ -2136,7 +2279,12 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
// necessary load/store logic.
//
return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
auto hoistedPrimalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);

// Lower CheckpointObject insts to a no-op.
lowerCheckpointObjectInsts(func);

return hoistedPrimalsInfo;
}

void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
Expand Down Expand Up @@ -2312,6 +2460,9 @@ static bool shouldStoreInst(IRInst* inst)

break;
}
case kIROp_CheckpointObject:
// Special inst for when a value must be stored.
return true;
default:
break;
}
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,8 @@ INST(BitNot, bitnot, 1, 0)

INST(Select, select, 3, 0)

INST(CheckpointObject, checkpointObj, 1, 0)

INST(GetStringHash, getStringHash, 1, 0)

INST(WaveGetActiveMask, waveGetActiveMask, 0, 0)
Expand Down
18 changes: 18 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -2664,6 +2664,22 @@ struct IRDiscard : IRTerminatorInst
{
};

// Used for representing a distinct copy of an object.
// This will get lowered into a no-op in the backend,
// but is useful for IR transformations that need to consider
// different uses of an inst separately.
//
// For example, when we hoist primal insts out of a loop,
// we need to make distinct copies of the inst for its uses
// within the loop body and outside of it.
//
struct IRCheckpointObject : IRInst
{
IR_LEAF_ISA(CheckpointObject);

IRInst* getVal() { return getOperand(0); }
};

// Signals that this point in the code should be unreachable.
// We can/should emit a dataflow error if we can ever determine
// that a block ending in one of these can actually be
Expand Down Expand Up @@ -4408,6 +4424,8 @@ struct IRBuilder

IRInst* emitDiscard();

IRInst* emitCheckpointObject(IRInst* value);

IRInst* emitUnreachable();
IRInst* emitMissingReturn();

Expand Down
Loading
Loading