Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix UseGraph::replace #6395

Merged
merged 11 commits into from
Feb 25, 2025
84 changes: 36 additions & 48 deletions source/slang/slang-ir-autodiff-primal-hoist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,11 @@ static int getInstRegionNestLevel(

struct UseChain
{
// The chain of uses from the base use to the relevant use.
// However, this is stored in reverse order (so that the last use is the 'base use')
//
List<IRUse*> chain;

static List<UseChain> from(
IRUse* baseUse,
Func<bool, IRUse*> isRelevantUse,
Expand Down Expand Up @@ -1366,41 +1370,20 @@ struct UseChain
return result;
}

void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst)
// This function only replaces the inner links, not the base use.
void replaceInnerLinks(IROutOfOrderCloneContext* ctx, IRBuilder* builder)
{
SLANG_ASSERT(chain.getCount() > 0);

// Simple case: if there is only one use, then we can just replace it.
if (chain.getCount() == 1)
{
builder->replaceOperand(chain.getLast(), inst);
chain.clear();
return;
}

// Pop the last use, which is the base use that needs to be replaced.
auto baseUse = chain.getLast();
chain.removeLast();
const UIndex count = chain.getCount();

// Ensure that replacement inst is set as mapping for the baseUse.
ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst;

IRBuilder chainBuilder(builder->getModule());
setInsertAfterOrdinaryInst(&chainBuilder, inst);

chain.reverse();
chain.removeLast();

// Clone the rest of the chain.
for (auto& use : chain)
// Process the chain in reverse order (excluding the first and last elements).
// That is, iterate from count - 2 down to 1 (inclusive).
for (int i = ((int)count) - 2; i >= 1; i--)
{
ctx->cloneInstOutOfOrder(&chainBuilder, use->get());
IRUse* use = chain[i];
ctx->cloneInstOutOfOrder(builder, use->get());
}

// We won't actually replace the final use, because if there are multiple chains
// it can cause problems. The parent UseGraph will handle that.

chain.clear();
}

IRInst* getUser() const
Expand All @@ -1417,6 +1400,14 @@ struct UseGraph
//
OrderedDictionary<IRUse*, List<UseChain>> chainSets;

// Create a UseGraph from a base inst.
//
// `isRelevantUse` is a predicate that determines if a use is relevant. Traversal will stop at
// this use, and all chains to this use will be grouped together.
//
// `passthroughInst` is a predicate that determines if an inst should be looked through
// for uses.
//
static UseGraph from(
IRInst* baseInst,
Func<bool, IRUse*> isRelevantUse,
Expand Down Expand Up @@ -1445,36 +1436,33 @@ struct UseGraph
return result;
}

void replace(IRBuilder* builder, IRUse* use, IRInst* inst)
void replace(IRBuilder* builder, IRUse* relevantUse, IRInst* inst)
{
// Since we may have common nodes, we will use an out-of-order cloning context
// that can retroactively correct the uses as needed.
//
IROutOfOrderCloneContext ctx;
List<UseChain> chains = chainSets[use];
for (auto chain : chains)
{
chain.replace(&ctx, builder, inst);
}
List<UseChain> chains = chainSets[relevantUse];

if (!isTrivial())
// Link the first use of each chain to inst.
for (auto& chain : chains)
ctx.cloneEnv.mapOldValToNew[chain.chain.getLast()->get()] = inst;

// Process the inner links of each chain using the replacement.
for (auto& chain : chains)
{
builder->setInsertBefore(use->getUser());
auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get());
IRBuilder chainBuilder(builder->getModule());
setInsertAfterOrdinaryInst(&chainBuilder, inst);

// Replace the base use.
builder->replaceOperand(use, lastInstInChain);
chain.replaceInnerLinks(&ctx, builder);
}
}

bool isTrivial()
{
// We're trivial if there's only one chain, and it has only one use.
if (chainSets.getCount() != 1)
return false;
// Finally, replace the relevant use (i.e, "final use") with the new replacement inst.
builder->setInsertBefore(relevantUse->getUser());
auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, relevantUse->get());

auto& chain = chainSets.getFirst().value;
return chain.getCount() == 1;
// Replace the base use.
builder->replaceOperand(relevantUse, lastInstInChain);
}

List<IRUse*> getUniqueUses() const
Expand Down
43 changes: 43 additions & 0 deletions tests/autodiff/dynamic-dispatch-ptr.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type

//CHECK: 1.0

//TEST_INPUT: type_conformance Sensor:ISensor = 1;

[anyValueSize(16)]
interface ISensor
{
[Differentiable]
float4 splat(float4 point);
}

struct Sensor : ISensor
{
[Differentiable]
float4 splat(float4 point)
{
return point;
}
}

[Differentiable]
float4 splat(ISensor* obj, float4 point)
{
return obj->splat(point);
}

//TEST_INPUT: set s = ubuffer(data=[0 0 1 0 0 0 0 0])
uniform ISensor *s;

//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<float4> outBuffer;

[shader("compute"), numthreads(1, 1, 1)]
void computeMain(
uint3 id : SV_DispatchThreadID
)
{
DifferentialPair<float4> dp;
__bwd_diff(splat)(s, dp, float4(1.0f));
outBuffer[id.x] = dp.d;
}
Loading